提交 a3e27970 编写于 作者: X Xinghai Sun

Add multiprocess version of xmap_reader to speedup training.

Add seqbin data parser to adapt to internal 1w data training.
上级 0173cc5c
...@@ -5,6 +5,8 @@ from __future__ import print_function ...@@ -5,6 +5,8 @@ from __future__ import print_function
import numpy as np import numpy as np
import io import io
import struct
import re
import soundfile import soundfile
import resampy import resampy
from scipy import signal from scipy import signal
...@@ -114,6 +116,46 @@ class AudioSegment(object): ...@@ -114,6 +116,46 @@ class AudioSegment(object):
data = sndfile.read(frames=end_frame - start_frame, dtype='float32') data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate) return cls(data, sample_rate)
@classmethod
def from_sequence_file(cls, filepath):
"""Create audio segment from sequence file.
:param filepath: Filepath of sequence file.
:type filepath: basestring
:return: Audio segment instance.
:rtype: AudioSegment
"""
# parse filepath
matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
if matches is None:
raise IOError("File type of %s is not supported" % filepath)
filename = matches.group(1)
fileno = int(matches.group(2))
# read headers
f = open(filename, 'rb')
version = f.read(4)
num_utterances = struct.unpack("i", f.read(4))[0]
bytes_per_header = struct.unpack("i", f.read(4))[0]
header_bytes = f.read(bytes_per_header * (num_utterances + 1))
header = [
struct.unpack("i", header_bytes[bytes_per_header * i:
bytes_per_header * (i + 1)])[0]
for i in range(num_utterances + 1)
]
# read audio bytes
f.seek(header[fileno - 1])
audio_bytes = f.read(header[fileno] - header[fileno - 1])
f.close()
# create audio segment
try:
return cls.from_bytes(audio_bytes)
except Exception as e:
samples = np.frombuffer(audio_bytes, dtype='int16')
return cls(samples=samples, sample_rate=8000)
@classmethod @classmethod
def from_bytes(cls, bytes): def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples. """Create audio segment from a byte string containing audio samples.
......
...@@ -7,11 +7,13 @@ from __future__ import print_function ...@@ -7,11 +7,13 @@ from __future__ import print_function
import random import random
import tarfile import tarfile
import re
import multiprocessing import multiprocessing
import numpy as np import numpy as np
import paddle.v2 as paddle import paddle.v2 as paddle
from threading import local from threading import local
from data_utils.utility import read_manifest from data_utils.utility import read_manifest
from data_utils.utility import xmap_readers_mp
from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment from data_utils.speech import SpeechSegment
...@@ -100,7 +102,14 @@ class DataGenerator(object): ...@@ -100,7 +102,14 @@ class DataGenerator(object):
transcription. transcription.
:rtype: tuple of (2darray, list) :rtype: tuple of (2darray, list)
""" """
speech_segment = SpeechSegment.from_file(filename, transcript) if filename.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(filename), transcript)
elif re.findall(r".seqbin_\d+$", filename):
speech_segment = SpeechSegment.from_sequence_file(filename,
transcript)
else:
speech_segment = SpeechSegment.from_file(filename, transcript)
self._augmentation_pipeline.transform_audio(speech_segment) self._augmentation_pipeline.transform_audio(speech_segment)
specgram, text_ids = self._speech_featurizer.featurize(speech_segment) specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
specgram = self._normalizer.apply(specgram) specgram = self._normalizer.apply(specgram)
...@@ -231,27 +240,23 @@ class DataGenerator(object): ...@@ -231,27 +240,23 @@ class DataGenerator(object):
result[tarinfo.name] = tarinfo result[tarinfo.name] = tarinfo
return f, result return f, result
def _get_file_object(self, file): def _subfile_from_tar(self, file):
"""Get file object by file path. """Get subfile object from tar.
If file startwith tar, it will return a tar file object It will return a subfile object from tar file
and cached tar file info for next reading request. and cached tar file info for next reading request.
It will return file directly, if the type of file is not str.
""" """
if file.startswith('tar:'): tarpath, filename = file.split(':', 1)[1].split('#', 1)
tarpath, filename = file.split(':', 1)[1].split('#', 1) if 'tar2info' not in self._local_data.__dict__:
if 'tar2info' not in self._local_data.__dict__: self._local_data.tar2info = {}
self._local_data.tar2info = {} if 'tar2object' not in self._local_data.__dict__:
if 'tar2object' not in self._local_data.__dict__: self._local_data.tar2object = {}
self._local_data.tar2object = {} if tarpath not in self._local_data.tar2info:
if tarpath not in self._local_data.tar2info: object, infoes = self._parse_tar(tarpath)
object, infoes = self._parse_tar(tarpath) self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2info[tarpath] = infoes self._local_data.tar2object[tarpath] = object
self._local_data.tar2object[tarpath] = object return self._local_data.tar2object[tarpath].extractfile(
return self._local_data.tar2object[tarpath].extractfile( self._local_data.tar2info[tarpath][filename])
self._local_data.tar2info[tarpath][filename])
else:
return open(file, 'r')
def _instance_reader_creator(self, manifest): def _instance_reader_creator(self, manifest):
""" """
...@@ -266,13 +271,12 @@ class DataGenerator(object): ...@@ -266,13 +271,12 @@ class DataGenerator(object):
for instance in manifest: for instance in manifest:
yield instance yield instance
def mapper(instance): return xmap_readers_mp(
return self.process_utterance( lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
self._get_file_object(instance["audio_filepath"]), reader,
instance["text"]) self._num_threads,
4096,
return paddle.reader.xmap_readers( order=True)
mapper, reader, self._num_threads, 1024, order=True)
def _padding_batch(self, batch, padding_to=-1, flatten=False): def _padding_batch(self, batch, padding_to=-1, flatten=False):
""" """
......
...@@ -44,12 +44,26 @@ class SpeechSegment(AudioSegment): ...@@ -44,12 +44,26 @@ class SpeechSegment(AudioSegment):
:type filepath: basestring|file :type filepath: basestring|file
:param transcript: Transcript text for the speech. :param transcript: Transcript text for the speech.
:type transript: basestring :type transript: basestring
:return: Audio segment instance. :return: Speech segment instance.
:rtype: AudioSegment :rtype: SpeechSegment
""" """
audio = AudioSegment.from_file(filepath) audio = AudioSegment.from_file(filepath)
return cls(audio.samples, audio.sample_rate, transcript) return cls(audio.samples, audio.sample_rate, transcript)
@classmethod
def from_sequence_file(cls, filepath, transcript):
"""Create speech segment from sequence file and transcript.
:param filepath: Filepath of sequence file.
:type filepath: basestring
:param transcript: Transcript text for the speech.
:type transript: basestring
:return: Speech segment instance.
:rtype: SpeechSegment
"""
audio = AudioSegment.from_sequence_file(filepath)
return cls(audio.samples, audio.sample_rate, transcript)
@classmethod @classmethod
def from_bytes(cls, bytes, transcript): def from_bytes(cls, bytes, transcript):
"""Create speech segment from a byte string and corresponding """Create speech segment from a byte string and corresponding
...@@ -59,8 +73,8 @@ class SpeechSegment(AudioSegment): ...@@ -59,8 +73,8 @@ class SpeechSegment(AudioSegment):
:type bytes: str :type bytes: str
:param transcript: Transcript text for the speech. :param transcript: Transcript text for the speech.
:type transript: basestring :type transript: basestring
:return: Audio segment instance. :return: Speech segment instance.
:rtype: AudioSegment :rtype: Speech Segment
""" """
audio = AudioSegment.from_bytes(bytes) audio = AudioSegment.from_bytes(bytes)
return cls(audio.samples, audio.sample_rate, transcript) return cls(audio.samples, audio.sample_rate, transcript)
......
...@@ -7,6 +7,9 @@ import json ...@@ -7,6 +7,9 @@ import json
import codecs import codecs
import os import os
import tarfile import tarfile
import time
from Queue import Queue
from multiprocessing import Process, Manager
from paddle.v2.dataset.common import md5file from paddle.v2.dataset.common import md5file
...@@ -61,3 +64,98 @@ def unpack(filepath, target_dir, rm_tar=False): ...@@ -61,3 +64,98 @@ def unpack(filepath, target_dir, rm_tar=False):
tar.close() tar.close()
if rm_tar == True: if rm_tar == True:
os.remove(filepath) os.remove(filepath)
class XmapEndSignal():
pass
def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
"""A multiprocessing pipeline wrapper for the data reader.
:param mapper: Function to map sample.
:type mapper: callable
:param reader: Given data reader.
:type reader: callable
:param process_num: Number of processes in the pipeline
:type process_num: int
:param buffer_size: Maximal buffer size.
:type buffer_size: int
:param order: Reserve the order of samples from the given reader.
:type order: bool
:return: The wrappered reader
:rtype: callable
"""
end_flag = XmapEndSignal()
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for sample in reader():
in_queue.put(sample)
in_queue.put(end_flag)
# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
for order_id, sample in enumerate(reader()):
in_queue.put((order_id, sample))
in_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
out_queue.put(mapper(sample))
sample = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
order_id, sample = ins
result = mapper(sample)
while order_id != out_order[0]:
time.sleep(0.001)
out_queue.put(result)
out_order[0] += 1
ins = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
def xreader():
# prepare shared memory
manager = Manager()
in_queue = manager.Queue(buffer_size)
out_queue = manager.Queue(buffer_size)
out_order = manager.list([0])
# start a read worker in a process
target = order_read_worker if order else read_worker
p = Process(target=target, args=(reader, in_queue))
p.start()
# start handle_workers with multiple processes
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = [
Process(target=target, args=args) for _ in xrange(process_num)
]
for w in workers:
w.start()
# get results
sample = out_queue.get()
while not isinstance(sample, XmapEndSignal):
yield sample
sample = out_queue.get()
finish = 1
while finish < process_num:
sample = out_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
yield sample
return xreader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册