From 97d11e43c92ac87521dfde2aab193b865cf33d17 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 26 Apr 2018 03:29:40 +0000 Subject: [PATCH] add --- .../transformer_nist_base/data_util.py | 6 +++--- .../transformer_nist_base/nmt_fluid.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fluid/neural_machine_translation/transformer_nist_base/data_util.py b/fluid/neural_machine_translation/transformer_nist_base/data_util.py index ffa11a7a..26d974e6 100644 --- a/fluid/neural_machine_translation/transformer_nist_base/data_util.py +++ b/fluid/neural_machine_translation/transformer_nist_base/data_util.py @@ -10,9 +10,9 @@ END_MARK = "" UNK_MARK = "" ''' -START_MARK = "<_GO>" -END_MARK = "<_EOS>" -UNK_MARK = "<_UNK>" +START_MARK = "_GO" +END_MARK = "_EOS" +UNK_MARK = "_UNK" class DataLoader(object): def __init__(self, diff --git a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py index bedddffb..8d467c39 100644 --- a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py +++ b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py @@ -280,8 +280,9 @@ def main(): ts = time.time() total = 0 pass_start_time = time.time() + #print len(train_reader) for batch_id, data in enumerate(train_reader): - print len(data) + #print len(data) if len(data) != args.batch_size: continue @@ -415,10 +416,11 @@ def main(): position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model), place) + #print "/root/data/nist06n/data-%d/part-*" % (args.task_index), train_reader = data_util.DataLoader( src_vocab_fpath="/root/data/nist06n/cn_30001.dict", trg_vocab_fpath="/root/data/nist06n/en_30001.dict", - fpattern="/root/data/nist06/data-%d/part-*" % (args.task_index), + fpattern="/root/data/nist06n/data-%d/part-*" % (args.task_index), batch_size=args.batch_size, token_batch_size=TrainTaskConfig.token_batch_size, sort_by_length=TrainTaskConfig.sort_by_length, -- GitLab