提交 84152a09 编写于 作者: Y Yibing Liu

Disable splitting of long sentences during inference

上级 b3ba7fda
...@@ -185,6 +185,9 @@ class AsyncDataReader(object): ...@@ -185,6 +185,9 @@ class AsyncDataReader(object):
corresponding description file. corresponding description file.
drop_frame_len (int): Samples whose label length above the value will be drop_frame_len (int): Samples whose label length above the value will be
dropped.(Using '-1' to disable the policy) dropped.(Using '-1' to disable the policy)
split_sentence_threshold (int): Sentences whose length is larger than
the value will be split.
(Assign -1 to disable splitting.)
proc_num (int): Number of processes for processing data. proc_num (int): Number of processes for processing data.
sample_buffer_size (int): Buffer size to indicate the maximum samples sample_buffer_size (int): Buffer size to indicate the maximum samples
cached. cached.
...@@ -204,6 +207,7 @@ class AsyncDataReader(object): ...@@ -204,6 +207,7 @@ class AsyncDataReader(object):
feature_file_list, feature_file_list,
label_file_list="", label_file_list="",
drop_frame_len=512, drop_frame_len=512,
split_sentence_threshold=512,
proc_num=10, proc_num=10,
sample_buffer_size=1024, sample_buffer_size=1024,
sample_info_buffer_size=1024, sample_info_buffer_size=1024,
...@@ -214,6 +218,7 @@ class AsyncDataReader(object): ...@@ -214,6 +218,7 @@ class AsyncDataReader(object):
self._feature_file_list = feature_file_list self._feature_file_list = feature_file_list
self._label_file_list = label_file_list self._label_file_list = label_file_list
self._drop_frame_len = drop_frame_len self._drop_frame_len = drop_frame_len
self._split_sentence_threshold = split_sentence_threshold
self._shuffle_block_num = shuffle_block_num self._shuffle_block_num = shuffle_block_num
self._block_info_list = None self._block_info_list = None
self._rng = random.Random(random_seed) self._rng = random.Random(random_seed)
...@@ -262,7 +267,8 @@ class AsyncDataReader(object): ...@@ -262,7 +267,8 @@ class AsyncDataReader(object):
map(lambda info: info[0], bucket_block_info), map(lambda info: info[0], bucket_block_info),
map(lambda info: info[1], bucket_block_info), map(lambda info: info[1], bucket_block_info),
map(lambda info: info[2], bucket_block_info), map(lambda info: info[2], bucket_block_info),
map(lambda info: info[3], bucket_block_info))) map(lambda info: info[3], bucket_block_info),
split_sentence_threshold=self._split_sentence_threshold))
# @TODO make this configurable # @TODO make this configurable
def set_transformers(self, transformers): def set_transformers(self, transformers):
......
...@@ -207,8 +207,11 @@ def infer_from_ckpt(args): ...@@ -207,8 +207,11 @@ def infer_from_ckpt(args):
label_t = fluid.LoDTensor() label_t = fluid.LoDTensor()
# infer data reader # infer data reader
infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst, infer_data_reader = reader.AsyncDataReader(
args.infer_label_lst) args.infer_feature_lst,
args.infer_label_lst,
drop_frame_len=-1,
split_sentence_threshold=-1)
infer_data_reader.set_transformers(ltrans) infer_data_reader.set_transformers(ltrans)
infer_costs, infer_accs = [], [] infer_costs, infer_accs = [], []
total_edit_dist, total_ref_len = 0.0, 0 total_edit_dist, total_ref_len = 0.0, 0
......
...@@ -187,7 +187,7 @@ def train(args): ...@@ -187,7 +187,7 @@ def train(args):
return -1.0, -1.0 return -1.0, -1.0
# test data reader # test data reader
test_data_reader = reader.AsyncDataReader(args.val_feature_lst, test_data_reader = reader.AsyncDataReader(args.val_feature_lst,
args.val_label_lst) args.val_label_lst, -1)
test_data_reader.set_transformers(ltrans) test_data_reader.set_transformers(ltrans)
test_costs, test_accs = [], [] test_costs, test_accs = [], []
for batch_id, batch_data in enumerate( for batch_id, batch_data in enumerate(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册