提交 a02b8f80 编写于 作者: Y yangyaming

Add clean callback.

上级 a0d1146b
...@@ -11,6 +11,7 @@ import multiprocessing ...@@ -11,6 +11,7 @@ import multiprocessing
import numpy as np import numpy as np
import paddle.v2 as paddle import paddle.v2 as paddle
from threading import local from threading import local
import atexit
from data_utils.utility import read_manifest from data_utils.utility import read_manifest
from data_utils.utility import xmap_readers_mp from data_utils.utility import xmap_readers_mp
from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.augmentor.augmentation import AugmentationPipeline
...@@ -274,13 +275,18 @@ class DataGenerator(object): ...@@ -274,13 +275,18 @@ class DataGenerator(object):
for instance in manifest: for instance in manifest:
yield instance yield instance
return xmap_readers_mp( reader, cleanup_callback = xmap_readers_mp(
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]), lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
reader, reader,
self._num_threads, self._num_threads,
4096, 4096,
order=True) order=True)
# register callback to main process
atexit.register(cleanup_callback)
return reader
def _padding_batch(self, batch, padding_to=-1, flatten=False): def _padding_batch(self, batch, padding_to=-1, flatten=False):
""" """
Padding audio features with zeros to make them have the same shape (or Padding audio features with zeros to make them have the same shape (or
......
...@@ -138,6 +138,10 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): ...@@ -138,6 +138,10 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
out_queue.put(sample) out_queue.put(sample)
out_queue.put(end_flag) out_queue.put(end_flag)
def cleanup():
# kill all sub process and threads
os._exit(0)
def xreader(): def xreader():
# prepare shared memory # prepare shared memory
manager = Manager() manager = Manager()
...@@ -174,4 +178,4 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): ...@@ -174,4 +178,4 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
yield sample yield sample
sample = flush_queue.get() sample = flush_queue.get()
return xreader return xreader, cleanup
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册