提交 c7b188c4 编写于 作者: R ranqiu

Fix bugs of external_memory.py and infer.py

上级 f456031b
......@@ -113,7 +113,7 @@ class ExternalMemory(object):
boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
input=[last_addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight
......
......@@ -4,6 +4,7 @@
import distutils.util
import argparse
import gzip
import random
import paddle.v2 as paddle
from external_memory import ExternalMemory
......@@ -118,10 +119,11 @@ def infer():
infer_data = []
random.seed(0) # for keeping consitancy for multiple runs
bounded_memory_perturbation = [[
random.gauss(0, memory_perturb_stddev) for i in xrange(args.hidden_size)
random.gauss(0, args.memory_perturb_stddev)
for i in xrange(args.hidden_size)
] for j in xrange(args.memory_slot_num)]
test_append_reader = reader_append_wrapper(
reader=paddle.dataset.wmt14.test(dict_size),
reader=paddle.dataset.wmt14.test(args.dict_size),
append_tuple=(bounded_memory_perturbation, ))
for i, item in enumerate(test_append_reader()):
if i < args.infer_data_num:
......@@ -134,8 +136,8 @@ def infer():
input=infer_data,
field=['prob', 'id'])
# parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(dict_size)
# parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
beam_probs, beam_sentences = parse_beam_search_result(beam_result,
target_dict)
for i in xrange(args.infer_data_num):
......@@ -147,7 +149,7 @@ def infer():
def main():
paddle.init(use_gpu=False, trainer_count=1)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
infer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册