提交 84d28c30 编写于 作者: Z zhxfl

Split long sentences into sub-sentences; allow disabling sentence dropping by passing -1

上级 85d8e5c8
......@@ -59,10 +59,19 @@ class SampleInfoBucket(object):
label_bin_paths (list|tuple): Files containing the binary label data.
label_desc_paths (list|tuple): Files containing the description of
samples' label data.
split_perturb (int): Maximum random perturbation added to the length of each sub-sentence when a long sentence is split.
split_sentence_threshold (int): Sentences with more frames than split_sentence_threshold are split into sub-sentences. Setting this (or either of the other split_* arguments) to -1 disables splitting.
split_sub_sentence_len (int): Base sub-sentence length. Each sub-sentence has length (split_sub_sentence_len + random.randint(0, split_perturb)), capped at the number of remaining frames.
"""
def __init__(self, feature_bin_paths, feature_desc_paths, label_bin_paths,
label_desc_paths):
def __init__(self,
feature_bin_paths,
feature_desc_paths,
label_bin_paths,
label_desc_paths,
split_perturb=50,
split_sentence_threshold=512,
split_sub_sentence_len=256):
block_num = len(label_bin_paths)
assert len(label_desc_paths) == block_num
assert len(feature_bin_paths) == block_num
......@@ -73,6 +82,9 @@ class SampleInfoBucket(object):
self._feature_desc_paths = feature_desc_paths
self._label_bin_paths = label_bin_paths
self._label_desc_paths = label_desc_paths
self._split_perturb = split_perturb
self._split_sentence_threshold = split_sentence_threshold
self._split_sub_sentence_len = split_sub_sentence_len
def generate_sample_info_list(self):
sample_info_list = []
......@@ -100,11 +112,40 @@ class SampleInfoBucket(object):
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
sample_info_list.append(
SampleInfo(feature_bin_path, feature_start, feature_size,
feature_frame_num, feature_dim, label_bin_path,
label_start, label_size, label_frame_num))
if self._split_sentence_threshold == -1 or self._split_perturb == -1 or self._split_sub_sentence_len == -1 or self._split_sentence_threshold >= feature_frame_num:
sample_info_list.append(
SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size,
label_frame_num))
#split sentence
else:
cur_frame_pos = 0
cur_frame_len = 0
remain_frame_num = feature_frame_num
while True:
if remain_frame_num > self._split_sentence_threshold:
cur_frame_len = self._split_sub_sentence_len + random.randint(
0, self._split_perturb)
if cur_frame_len > remain_frame_num:
cur_frame_len = remain_frame_num
else:
cur_frame_len = remain_frame_num
sample_info_list.append(
SampleInfo(
feature_bin_path, feature_start + cur_frame_pos
* feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len))
remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len
if remain_frame_num <= 0:
break
print("generate_sample_info_list size ", len(sample_info_list))
return sample_info_list
......@@ -143,7 +184,7 @@ class DataReader(object):
sample_buffer_size=1024,
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1,
shuffle_block_num=10,
random_seed=0):
self._feature_file_list = feature_file_list
self._label_file_list = label_file_list
......@@ -260,7 +301,8 @@ class DataReader(object):
time.sleep(0.001)
# drop long sentence
if self._drop_frame_len >= sample_data[0].shape[0]:
if self._drop_frame_len == -1 or self._drop_frame_len >= sample_data[
0].shape[0]:
sample_queue.put(sample_data)
out_order[0] += 1
......@@ -281,7 +323,6 @@ class DataReader(object):
w.start()
finished_process_num = 0
while finished_process_num < self._process_num:
sample = sample_queue.get()
if isinstance(sample, EpochEndSignal):
......
......@@ -173,7 +173,7 @@ def train(args):
trans_splice.TransSplice()
]
data_reader = reader.DataReader(args.feature_lst, args.label_lst)
data_reader = reader.DataReader(args.feature_lst, args.label_lst, -1)
data_reader.set_transformers(ltrans)
res_feature = fluid.LoDTensor()
......@@ -198,7 +198,8 @@ def train(args):
fetch_list=[avg_cost] + accuracy.metrics,
return_numpy=False)
train_acc = accuracy.eval(exe)
print("acc:", lodtensor_to_ndarray(loss))
print("pass_id", pass_id, "batch_id", batch_id, "acc:",
lodtensor_to_ndarray(loss))
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册