提交 606cb22e 编写于 作者: Z zhxfl

Merge remote-tracking branch 'upstream/develop' into fix-661

......@@ -15,13 +15,12 @@ from multiprocessing import Manager, Process
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
from data_utils.util import CriticalException, ForceExitWrapper
class SampleInfo(object):
"""SampleInfo holds the necessary information to load a sample from disk.
Args:
feature_bin_path (str): File containing the feature data.
feature_start (int): Start position of the sample's feature data.
......@@ -54,6 +53,7 @@ class SampleInfoBucket(object):
data, sample start position, sample byte number etc.) to access samples'
feature data and the same with the label description file. SampleInfoBucket
is the minimum unit to do shuffle.
Args:
feature_bin_paths (list|tuple): Files containing the binary feature
data.
......@@ -67,8 +67,8 @@ class SampleInfoBucket(object):
split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \
rand() % split_perturb).
(split_sub_sentence_len
+ rand() % split_perturb).
"""
def __init__(self,
......@@ -160,9 +160,14 @@ class SampleInfoBucket(object):
return sample_info_list
class EpochEndSignal():
pass
class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation.
Args:
feature_file_list (str): File containing paths of feature data file and
corresponding description file.
......@@ -206,17 +211,12 @@ class AsyncDataReader(object):
self.generate_bucket_list(True)
self._order_id = 0
self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
# buffer queue
self._sample_info_queue = self._manager.Queue(sample_info_buffer_size)
self._sample_queue = self._manager.Queue(sample_buffer_size)
self._batch_queue = self._manager.Queue(batch_buffer_size)
def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
......@@ -250,21 +250,13 @@ class AsyncDataReader(object):
def set_transformers(self, transformers):
self._transformers = transformers
def recycle(self, *args):
for shared_ndarray in args:
if not isinstance(shared_ndarray, SharedNDArray):
raise Value("Only support recycle SharedNDArray object.")
shared_ndarray.recycle(self._pool_manager.pool)
def _start_async_processing(self):
def _sample_generator(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
for sample_info_bucket in self._bucket_list:
try:
sample_info_list = \
......@@ -277,14 +269,13 @@ class AsyncDataReader(object):
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1
for i in xrange(self._sample_proc_num):
for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal())
feeding_proc = DaemonProcessGroup(
proc_num=1,
target=ordered_feeding_task,
args=(self._sample_info_queue, ))
feeding_proc.start_all()
feeding_thread = Thread(
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
......@@ -312,11 +303,12 @@ class AsyncDataReader(object):
sample_info.feature_size)
assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
* sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
......@@ -360,83 +352,83 @@ class AsyncDataReader(object):
sample_queue.put(EpochEndSignal())
out_order = self._manager.list([0])
args = (self._sample_info_queue, self._sample_queue, out_order)
sample_proc = DaemonProcessGroup(
proc_num=self._sample_proc_num,
target=ordered_processing_task,
args=args)
sample_proc.start_all()
args = (sample_info_queue, sample_queue, out_order)
workers = [
Process(
target=ordered_processing_task, args=args)
for _ in xrange(self._proc_num)
]
def batch_iterator(self, batch_size, minimum_batch_size):
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool):
def conv_to_shared(ndarray):
while self._force_exit == False:
try:
(name, shared_ndarray) = pool.popitem()
except Exception as e:
time.sleep(0.001)
else:
shared_ndarray.copy(ndarray)
return shared_ndarray
for w in workers:
w.daemon = True
w.start()
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
finished_proc_num = 0
batch_samples = []
lod = [0]
done_num = 0
while done_num < self._sample_proc_num:
sample = sample_queue.get()
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
done_num += 1
else:
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
feature, label = batch_to_ndarray(batch_samples, lod)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
finished_proc_num += 1
if finished_proc_num >= self._proc_num:
break
else:
continue
batch_queue.put((feature, label, lod))
batch_samples = []
lod = [0]
yield sample
if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod)
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
(batch_feature, batch_label) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]
batch_queue.put((feature, label, lod))
if len(batch_samples) >= minimum_batch_size:
(batch_feature, batch_label) = batch_to_ndarray(batch_samples,
lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_queue.put(EpochEndSignal())
self._start_async_processing()
batch_queue = Queue.Queue(self._batch_buffer_size)
self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)
assembling_proc = DaemonProcessGroup(
proc_num=1,
assembling_thread = Thread(
target=batch_assembling_task,
args=(self._sample_queue, self._batch_queue,
self._pool_manager.pool))
assembling_proc.start_all()
args=(self._sample_generator, batch_queue))
assembling_thread.daemon = True
assembling_thread.start()
while self._force_exit == False:
try:
batch_data = self._batch_queue.get_nowait()
batch_data = batch_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data
# clean the shared memory
del self._pool_manager
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys, time
import sys
from six import reraise
from tblib import Traceback
from multiprocessing import Manager, Process
import posix_ipc, mmap
import numpy as np
......@@ -37,19 +35,6 @@ def lodtensor_to_ndarray(lod_tensor):
return ret, lod_tensor.lod()
def batch_to_ndarray(batch_samples, lod):
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
def split_infer_result(infer_seq, lod):
infer_batch = []
for i in xrange(0, len(lod[0]) - 1):
......@@ -57,127 +42,10 @@ def split_infer_result(infer_seq, lod):
return infer_batch
class DaemonProcessGroup(object):
def __init__(self, proc_num, target, args):
self._proc_num = proc_num
self._workers = [
Process(
target=target, args=args) for _ in xrange(self._proc_num)
]
def start_all(self):
for w in self._workers:
w.daemon = True
w.start()
@property
def proc_num(self):
return self._proc_num
class EpochEndSignal(object):
pass
class CriticalException(Exception):
pass
class SharedNDArray(object):
"""SharedNDArray utilizes shared memory to avoid data serialization when
data object shared among different processes. We can reconstruct the
`ndarray` when memory address, shape and dtype provided.
Args:
name (str): Address name of shared memory.
whether_verify (bool): Whether to validate the writing operation.
"""
def __init__(self, name, whether_verify=False):
self._name = name
self._shm = None
self._buf = None
self._array = np.zeros(1, dtype=np.float32)
self._inited = False
self._whether_verify = whether_verify
def zeros_like(self, shape, dtype):
size = int(np.prod(shape)) * np.dtype(dtype).itemsize
if self._inited:
self._shm = posix_ipc.SharedMemory(self._name)
else:
self._shm = posix_ipc.SharedMemory(
self._name, posix_ipc.O_CREAT, size=size)
self._buf = mmap.mmap(self._shm.fd, size)
self._array = np.ndarray(shape, dtype, self._buf, order='C')
def copy(self, ndarray):
size = int(np.prod(ndarray.shape)) * np.dtype(ndarray.dtype).itemsize
self.zeros_like(ndarray.shape, ndarray.dtype)
self._array[:] = ndarray
self._buf.flush()
self._inited = True
if self._whether_verify:
shm = posix_ipc.SharedMemory(self._name)
buf = mmap.mmap(shm.fd, size)
array = np.ndarray(ndarray.shape, ndarray.dtype, buf, order='C')
np.testing.assert_array_equal(array, ndarray)
@property
def ndarray(self):
return self._array
def recycle(self, pool):
self._buf.close()
self._shm.close_fd()
self._inited = False
pool[self._name] = self
def __getstate__(self):
return (self._name, self._array.shape, self._array.dtype, self._inited,
self._whether_verify)
def __setstate__(self, state):
self._name = state[0]
self._inited = state[3]
self.zeros_like(state[1], state[2])
self._whether_verify = state[4]
class SharedMemoryPoolManager(object):
"""SharedMemoryPoolManager maintains a multiprocessing.Manager.dict object.
All available addresses are allocated once and will be reused. Though this
class is not process-safe, the pool can be shared between processes. All
shared memory should be unlinked before the main process exited.
Args:
pool_size (int): Size of shared memory pool.
manager (dict): A multiprocessing.Manager object, the pool is
maintained by the proxy process.
name_prefix (str): Address prefix of shared memory.
"""
def __init__(self, pool_size, manager, name_prefix='/deep_asr'):
self._names = []
self._dict = manager.dict()
self._time_prefix = time.strftime('%Y%m%d%H%M%S')
for i in xrange(pool_size):
name = name_prefix + '_' + self._time_prefix + '_' + str(i)
self._dict[name] = SharedNDArray(name)
self._names.append(name)
@property
def pool(self):
return self._dict
def __del__(self):
for name in self._names:
# have to unlink the shared memory
posix_ipc.unlink_shared_memory(name)
def suppress_signal(signo, stack_frame):
pass
......
......@@ -21,14 +21,15 @@ using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename) {
std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale) {
const char* usage =
"Decode, reading log-likelihoods (of transition-ids or whatever symbol "
"is on the graph) as matrices.";
kaldi::ParseOptions po(usage);
binary = true;
acoustic_scale = 1.5;
this->acoustic_scale = acoustic_scale;
allow_partial = true;
kaldi::FasterDecoderOptions decoder_opts;
decoder_opts.Register(&po, true); // true == include obscure settings.
......
......@@ -29,7 +29,8 @@ class Decoder {
public:
Decoder(std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename);
std::string logprior_rxfilename,
kaldi::BaseFloat acoustic_scale);
~Decoder();
// Interface to accept the scores read from specifier and return
......
......@@ -23,7 +23,7 @@ PYBIND11_MODULE(post_decode_faster, m) {
m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string>())
.def(py::init<std::string, std::string, std::string, kaldi::BaseFloat>())
.def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode,
......
......@@ -8,7 +8,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from data_utils.util import split_infer_result
......@@ -79,12 +79,13 @@ def infer(args):
trans_splice.TransSplice()
]
infer_data_reader = reader.DataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader.set_transformers(ltrans)
feature_t = fluid.LoDTensor()
one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next()
(features, labels, lod) = one_batch
feature_t.set(features, place)
feature_t.set_lod([lod])
......
......@@ -106,6 +106,11 @@ def parse_args():
type=str,
default="./decoder/logprior",
help="The log prior probs for training data. (default: %(default)s)")
parser.add_argument(
'--acoustic_scale',
type=float,
default=0.2,
help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
args = parser.parse_args()
return args
......@@ -143,6 +148,10 @@ def infer_from_ckpt(args):
# load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint)
# init decoder
decoder = Decoder(args.vocabulary, args.graphs, args.log_prior,
args.acoustic_scale)
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
......@@ -162,12 +171,10 @@ def infer_from_ckpt(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
infer_data_reader.recycle(features, labels, lod)
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
results = exe.run(infer_program,
feed={"feature": feature_t,
......@@ -181,7 +188,7 @@ def infer_from_ckpt(args):
infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch):
key = "utter#%d" % (batch_id * args.batch_size + index)
print(key, ": ", decoder.decode(key, sample), "\n")
print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
print(np.mean(infer_costs), np.mean(infer_accs))
......
......@@ -169,14 +169,12 @@ def profile(args):
frames_seen = 0
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
frames_seen += lod.ndarray[-1]
data_reader.recycle(features, labels, lod)
frames_seen += lod[-1]
outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t,
......
......@@ -193,12 +193,10 @@ def train(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
test_data_reader.recycle(features, labels, lod)
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
cost, acc = exe.run(test_program,
feed={"feature": feature_t,
......@@ -221,12 +219,10 @@ def train(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
train_data_reader.recycle(features, labels, lod)
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
outs = exe.run(fluid.default_main_program(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册