提交 555b2dfd 编写于 作者: L Luo Tao

add seqtext_print for seqToseq demo

上级 b25c5124
......@@ -126,7 +126,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def main():
paddle.init(use_gpu=False, trainer_count=1)
is_generating = True
is_generating = False
# source and target dict dim.
dict_size = 30000
......@@ -167,16 +167,47 @@ def main():
# generate a english sequence to french
else:
gen_creator = paddle.dataset.wmt14.test(dict_size)
# use the first 3 samples for generation
gen_creator = paddle.dataset.wmt14.gen(dict_size)
gen_data = []
gen_num = 3
for item in gen_creator():
gen_data.append((item[0], ))
if len(gen_data) == 3:
if len(gen_data) == gen_num:
break
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
# get the pretrained model, whose bleu = 26.92
parameters = paddle.dataset.wmt14.model()
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
# prob is the prediction probabilities, and id is the prediction word.
beam_result = paddle.infer(
output_layer=beam_gen,
parameters=parameters,
input=gen_data,
field=['prob', 'id'])
# get the dictionary
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
# the delimited element of generated sequences is -1,
# the first element of each generated sequence is the sequence length
seq_list = []
seq = []
for w in beam_result[1]:
if w != -1:
seq.append(w)
else:
seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
seq = []
prob = beam_result[0]
beam_size = 3
for i in xrange(gen_num):
print "\n*******************************************************\n"
print "src:", ' '.join(
[src_dict.get(w) for w in gen_data[i][0]]), "\n"
for j in xrange(beam_size):
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]
if __name__ == '__main__':
......
......@@ -26,7 +26,7 @@ URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/de
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
......@@ -108,6 +108,11 @@ def test(dict_size):
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
def gen(dict_size):
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size)
def model():
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
with gzip.open(tar_file, 'r') as f:
......@@ -115,10 +120,15 @@ def model():
return parameters
def trg_dict(dict_size):
def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
return trg_dict
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()}
return src_dict, trg_dict
def fetch():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册