提交 b420c007,作者:Yu Yang

Add pickle

上级 520612bf
......@@ -2,7 +2,7 @@ import glob
import os
import random
import tarfile
import time
import cPickle
class SortType(object):
......@@ -179,7 +179,8 @@ class DataReader(object):
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
seed=0):
seed=0,
pkl_filename=None):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
if trg_vocab_fpath is not None:
......@@ -195,8 +196,23 @@ class DataReader(object):
self._min_length = min_length
self._max_length = max_length
self._delimiter = delimiter
if pkl_filename is None:
self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
unk_mark)
else:
try:
with open(pkl_filename, 'r') as f:
self._src_seq_ids, self._trg_seq_ids, self._sample_infos = cPickle.load(
f)
except:
self.load_src_trg_ids(end_mark, fpattern, start_mark, tarfile,
unk_mark)
with open(pkl_filename, 'w') as f:
cPickle.dump((self._src_seq_ids, self._trg_seq_ids,
self._sample_infos), f,
cPickle.HIGHEST_PROTOCOL)
self._random = random.Random(x=seed)
def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册