提交 84152a09 编写于 作者: Y Yibing Liu

Disable splitting long sentence in infer

上级 b3ba7fda
......@@ -185,6 +185,9 @@ class AsyncDataReader(object):
corresponding description file.
drop_frame_len (int): Samples whose label length is above this value
will be dropped. (Use '-1' to disable the policy.)
split_sentence_threshold (int): Sentences whose length is larger than
this value will trigger the split operation.
(Assign -1 to disable splitting.)
proc_num (int): Number of processes for processing data.
sample_buffer_size (int): Buffer size indicating the maximum number of
samples cached.
......@@ -204,6 +207,7 @@ class AsyncDataReader(object):
feature_file_list,
label_file_list="",
drop_frame_len=512,
split_sentence_threshold=512,
proc_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=1024,
......@@ -214,6 +218,7 @@ class AsyncDataReader(object):
self._feature_file_list = feature_file_list
self._label_file_list = label_file_list
self._drop_frame_len = drop_frame_len
self._split_sentence_threshold = split_sentence_threshold
self._shuffle_block_num = shuffle_block_num
self._block_info_list = None
self._rng = random.Random(random_seed)
......@@ -262,7 +267,8 @@ class AsyncDataReader(object):
map(lambda info: info[0], bucket_block_info),
map(lambda info: info[1], bucket_block_info),
map(lambda info: info[2], bucket_block_info),
map(lambda info: info[3], bucket_block_info)))
map(lambda info: info[3], bucket_block_info),
split_sentence_threshold=self._split_sentence_threshold))
# @TODO make this configurable
def set_transformers(self, transformers):
......
......@@ -207,8 +207,11 @@ def infer_from_ckpt(args):
label_t = fluid.LoDTensor()
# infer data reader
infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader = reader.AsyncDataReader(
args.infer_feature_lst,
args.infer_label_lst,
drop_frame_len=-1,
split_sentence_threshold=-1)
infer_data_reader.set_transformers(ltrans)
infer_costs, infer_accs = [], []
total_edit_dist, total_ref_len = 0.0, 0
......
......@@ -187,7 +187,7 @@ def train(args):
return -1.0, -1.0
# test data reader
test_data_reader = reader.AsyncDataReader(args.val_feature_lst,
args.val_label_lst)
args.val_label_lst, -1)
test_data_reader.set_transformers(ltrans)
test_costs, test_accs = [], []
for batch_id, batch_data in enumerate(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册