diff --git a/deep_speech_2/data_utils/audio.py b/deep_speech_2/data_utils/audio.py index 30e25221cd84aa6849061635749188e3bd13d67b..895a7899c0b293a3834af9759d3804da4257cc8a 100644 --- a/deep_speech_2/data_utils/audio.py +++ b/deep_speech_2/data_utils/audio.py @@ -5,6 +5,8 @@ from __future__ import print_function import numpy as np import io +import struct +import re import soundfile import resampy from scipy import signal @@ -114,6 +116,46 @@ class AudioSegment(object): data = sndfile.read(frames=end_frame - start_frame, dtype='float32') return cls(data, sample_rate) + @classmethod + def from_sequence_file(cls, filepath): + """Create audio segment from sequence file. + + :param filepath: Filepath of sequence file. + :type filepath: basestring + :return: Audio segment instance. + :rtype: AudioSegment + """ + # parse filepath + matches = re.match(r"(.+\.seqbin)_(\d+)", filepath) + if matches is None: + raise IOError("File type of %s is not supported" % filepath) + filename = matches.group(1) + fileno = int(matches.group(2)) + + # read headers + f = open(filename, 'rb') + version = f.read(4) + num_utterances = struct.unpack("i", f.read(4))[0] + bytes_per_header = struct.unpack("i", f.read(4))[0] + header_bytes = f.read(bytes_per_header * (num_utterances + 1)) + header = [ + struct.unpack("i", header_bytes[bytes_per_header * i: + bytes_per_header * (i + 1)])[0] + for i in range(num_utterances + 1) + ] + + # read audio bytes + f.seek(header[fileno - 1]) + audio_bytes = f.read(header[fileno] - header[fileno - 1]) + f.close() + + # create audio segment + try: + return cls.from_bytes(audio_bytes) + except Exception as e: + samples = np.frombuffer(audio_bytes, dtype='int16') + return cls(samples=samples, sample_rate=8000) + @classmethod def from_bytes(cls, bytes): """Create audio segment from a byte string containing audio samples. diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 7ddf1f339d6ebe5c7bc40317f9571d200e98d4e5..fca5381758f3a0a97d1d288019f0d909eca2469a 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -7,11 +7,13 @@ from __future__ import print_function import random import tarfile +import re import multiprocessing import numpy as np import paddle.v2 as paddle from threading import local from data_utils.utility import read_manifest +from data_utils.utility import xmap_readers_mp from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.featurizer.speech_featurizer import SpeechFeaturizer from data_utils.speech import SpeechSegment @@ -100,7 +102,14 @@ class DataGenerator(object): transcription. :rtype: tuple of (2darray, list) """ - speech_segment = SpeechSegment.from_file(filename, transcript) + if filename.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(filename), transcript) + elif re.findall(r".seqbin_\d+$", filename): + speech_segment = SpeechSegment.from_sequence_file(filename, + transcript) + else: + speech_segment = SpeechSegment.from_file(filename, transcript) self._augmentation_pipeline.transform_audio(speech_segment) specgram, text_ids = self._speech_featurizer.featurize(speech_segment) specgram = self._normalizer.apply(specgram) @@ -231,27 +240,23 @@ class DataGenerator(object): result[tarinfo.name] = tarinfo return f, result - def _get_file_object(self, file): - """Get file object by file path. + def _subfile_from_tar(self, file): + """Get subfile object from tar. - If file startwith tar, it will return a tar file object + It will return a subfile object from tar file and cached tar file info for next reading request. - It will return file directly, if the type of file is not str. """ - if file.startswith('tar:'): - tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in self._local_data.__dict__: - self._local_data.tar2info = {} - if 'tar2object' not in self._local_data.__dict__: - self._local_data.tar2object = {} - if tarpath not in self._local_data.tar2info: - object, infoes = self._parse_tar(tarpath) - self._local_data.tar2info[tarpath] = infoes - self._local_data.tar2object[tarpath] = object - return self._local_data.tar2object[tarpath].extractfile( - self._local_data.tar2info[tarpath][filename]) - else: - return open(file, 'r') + tarpath, filename = file.split(':', 1)[1].split('#', 1) + if 'tar2info' not in self._local_data.__dict__: + self._local_data.tar2info = {} + if 'tar2object' not in self._local_data.__dict__: + self._local_data.tar2object = {} + if tarpath not in self._local_data.tar2info: + object, infoes = self._parse_tar(tarpath) + self._local_data.tar2info[tarpath] = infoes + self._local_data.tar2object[tarpath] = object + return self._local_data.tar2object[tarpath].extractfile( + self._local_data.tar2info[tarpath][filename]) def _instance_reader_creator(self, manifest): """ @@ -266,13 +271,12 @@ class DataGenerator(object): for instance in manifest: yield instance - def mapper(instance): - return self.process_utterance( - self._get_file_object(instance["audio_filepath"]), - instance["text"]) - - return paddle.reader.xmap_readers( - mapper, reader, self._num_threads, 1024, order=True) + return xmap_readers_mp( + lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]), + reader, + self._num_threads, + 4096, + order=True) def _padding_batch(self, batch, padding_to=-1, flatten=False): """ diff --git a/deep_speech_2/data_utils/speech.py b/deep_speech_2/data_utils/speech.py index 17d68f315d04b6cc1aae2346df78cf77982cd7bc..623b38c24132519de737dafb1a782293d5db24f7 100644 --- a/deep_speech_2/data_utils/speech.py +++ b/deep_speech_2/data_utils/speech.py @@ -44,12 +44,26 @@ class SpeechSegment(AudioSegment): :type filepath: basestring|file :param transcript: Transcript text for the speech. :type transript: basestring - :return: Audio segment instance. - :rtype: AudioSegment + :return: Speech segment instance. + :rtype: SpeechSegment """ audio = AudioSegment.from_file(filepath) return cls(audio.samples, audio.sample_rate, transcript) + @classmethod + def from_sequence_file(cls, filepath, transcript): + """Create speech segment from sequence file and transcript. + + :param filepath: Filepath of sequence file. + :type filepath: basestring + :param transcript: Transcript text for the speech. + :type transript: basestring + :return: Speech segment instance. + :rtype: SpeechSegment + """ + audio = AudioSegment.from_sequence_file(filepath) + return cls(audio.samples, audio.sample_rate, transcript) + @classmethod def from_bytes(cls, bytes, transcript): """Create speech segment from a byte string and corresponding @@ -59,8 +73,8 @@ class SpeechSegment(AudioSegment): :type bytes: str :param transcript: Transcript text for the speech. :type transript: basestring - :return: Audio segment instance. - :rtype: AudioSegment + :return: Speech segment instance. + :rtype: Speech Segment """ audio = AudioSegment.from_bytes(bytes) return cls(audio.samples, audio.sample_rate, transcript) diff --git a/deep_speech_2/data_utils/utility.py b/deep_speech_2/data_utils/utility.py index da7b66ef2f65699678c09def05ee95fe5c52c52f..40e212c89c3977890d7dc87c5076e3b7f62bc380 100644 --- a/deep_speech_2/data_utils/utility.py +++ b/deep_speech_2/data_utils/utility.py @@ -7,6 +7,9 @@ import json import codecs import os import tarfile +import time +from Queue import Queue +from multiprocessing import Process, Manager from paddle.v2.dataset.common import md5file @@ -61,3 +64,98 @@ def unpack(filepath, target_dir, rm_tar=False): tar.close() if rm_tar == True: os.remove(filepath) + + +class XmapEndSignal(): + pass + + +def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): + """A multiprocessing pipeline wrapper for the data reader. + + :param mapper: Function to map sample. + :type mapper: callable + :param reader: Given data reader. + :type reader: callable + :param process_num: Number of processes in the pipeline + :type process_num: int + :param buffer_size: Maximal buffer size. + :type buffer_size: int + :param order: Reserve the order of samples from the given reader. + :type order: bool + :return: The wrappered reader + :rtype: callable + """ + end_flag = XmapEndSignal() + + # define a worker to read samples from reader to in_queue + def read_worker(reader, in_queue): + for sample in reader(): + in_queue.put(sample) + in_queue.put(end_flag) + + # define a worker to read samples from reader to in_queue with order flag + def order_read_worker(reader, in_queue): + for order_id, sample in enumerate(reader()): + in_queue.put((order_id, sample)) + in_queue.put(end_flag) + + # define a worker to handle samples from in_queue by mapper and put results to out_queue + def handle_worker(in_queue, out_queue, mapper): + sample = in_queue.get() + while not isinstance(sample, XmapEndSignal): + out_queue.put(mapper(sample)) + sample = in_queue.get() + in_queue.put(end_flag) + out_queue.put(end_flag) + + # define a worker to handle samples from in_queue by mapper and put results to out_queue with order + def order_handle_worker(in_queue, out_queue, mapper, out_order): + ins = in_queue.get() + while not isinstance(ins, XmapEndSignal): + order_id, sample = ins + result = mapper(sample) + while order_id != out_order[0]: + time.sleep(0.001) + out_queue.put(result) + out_order[0] += 1 + ins = in_queue.get() + in_queue.put(end_flag) + out_queue.put(end_flag) + + def xreader(): + # prepare shared memory + manager = Manager() + in_queue = manager.Queue(buffer_size) + out_queue = manager.Queue(buffer_size) + out_order = manager.list([0]) + + # start a read worker in a process + target = order_read_worker if order else read_worker + p = Process(target=target, args=(reader, in_queue)) + p.start() + + # start handle_workers with multiple processes + target = order_handle_worker if order else handle_worker + args = (in_queue, out_queue, mapper, out_order) if order else ( + in_queue, out_queue, mapper) + workers = [ + Process(target=target, args=args) for _ in xrange(process_num) + ] + for w in workers: + w.start() + + # get results + sample = out_queue.get() + while not isinstance(sample, XmapEndSignal): + yield sample + sample = out_queue.get() + finish = 1 + while finish < process_num: + sample = out_queue.get() + if isinstance(sample, XmapEndSignal): + finish += 1 + else: + yield sample + + return xreader