提交 97d11e43 编写于 作者: G gongweibao

add

上级 55cf75df
......@@ -10,9 +10,9 @@ END_MARK = "<e>"
UNK_MARK = "<unk>"
'''
START_MARK = "<_GO>"
END_MARK = "<_EOS>"
UNK_MARK = "<_UNK>"
START_MARK = "_GO"
END_MARK = "_EOS"
UNK_MARK = "_UNK"
class DataLoader(object):
def __init__(self,
......
......@@ -280,8 +280,9 @@ def main():
ts = time.time()
total = 0
pass_start_time = time.time()
#print len(train_reader)
for batch_id, data in enumerate(train_reader):
print len(data)
#print len(data)
if len(data) != args.batch_size:
continue
......@@ -415,10 +416,11 @@ def main():
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)
#print "/root/data/nist06n/data-%d/part-*" % (args.task_index),
train_reader = data_util.DataLoader(
src_vocab_fpath="/root/data/nist06n/cn_30001.dict",
trg_vocab_fpath="/root/data/nist06n/en_30001.dict",
fpattern="/root/data/nist06/data-%d/part-*" % (args.task_index),
fpattern="/root/data/nist06n/data-%d/part-*" % (args.task_index),
batch_size=args.batch_size,
token_batch_size=TrainTaskConfig.token_batch_size,
sort_by_length=TrainTaskConfig.sort_by_length,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册