提交 c7b188c4 编写于 作者: R ranqiu

Fix bugs of external_memory.py and infer.py

上级 f456031b
...@@ -113,7 +113,7 @@ class ExternalMemory(object): ...@@ -113,7 +113,7 @@ class ExternalMemory(object):
boot_layer=self.initial_weight) boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation( interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight], input=[last_addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight)) weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight return interpolated_weight
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import distutils.util import distutils.util
import argparse import argparse
import gzip import gzip
import random
import paddle.v2 as paddle import paddle.v2 as paddle
from external_memory import ExternalMemory from external_memory import ExternalMemory
...@@ -118,10 +119,11 @@ def infer(): ...@@ -118,10 +119,11 @@ def infer():
infer_data = [] infer_data = []
random.seed(0) # for keeping consitancy for multiple runs random.seed(0) # for keeping consitancy for multiple runs
bounded_memory_perturbation = [[ bounded_memory_perturbation = [[
random.gauss(0, memory_perturb_stddev) for i in xrange(args.hidden_size) random.gauss(0, args.memory_perturb_stddev)
for i in xrange(args.hidden_size)
] for j in xrange(args.memory_slot_num)] ] for j in xrange(args.memory_slot_num)]
test_append_reader = reader_append_wrapper( test_append_reader = reader_append_wrapper(
reader=paddle.dataset.wmt14.test(dict_size), reader=paddle.dataset.wmt14.test(args.dict_size),
append_tuple=(bounded_memory_perturbation, )) append_tuple=(bounded_memory_perturbation, ))
for i, item in enumerate(test_append_reader()): for i, item in enumerate(test_append_reader()):
if i < args.infer_data_num: if i < args.infer_data_num:
...@@ -135,7 +137,7 @@ def infer(): ...@@ -135,7 +137,7 @@ def infer():
field=['prob', 'id']) field=['prob', 'id'])
# parse beam result and print # parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(dict_size) source_dict, target_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
beam_probs, beam_sentences = parse_beam_search_result(beam_result, beam_probs, beam_sentences = parse_beam_search_result(beam_result,
target_dict) target_dict)
for i in xrange(args.infer_data_num): for i in xrange(args.infer_data_num):
...@@ -147,7 +149,7 @@ def infer(): ...@@ -147,7 +149,7 @@ def infer():
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
infer() infer()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册