diff --git a/python/paddle/utils/image_multiproc.py b/python/paddle/utils/image_multiproc.py index 82df6d6c0c21e9f5e893fdf404811dc90a330bf6..6ce32f7811d6be6864a567cf41bf408f422409a7 100644 --- a/python/paddle/utils/image_multiproc.py +++ b/python/paddle/utils/image_multiproc.py @@ -3,7 +3,8 @@ import numpy as np from PIL import Image from cStringIO import StringIO import multiprocessing -from functools import partial +import functools +import itertools from paddle.utils.image_util import * from paddle.trainer.config_parser import logger @@ -14,10 +15,12 @@ except ImportError: logger.warning("OpenCV2 is not installed, using PIL to prcoess") cv2 = None +__all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"] -class CvTransfomer(ImageTransformer): + +class CvTransformer(ImageTransformer): """ - CvTransfomer used python-opencv to process image. + CvTransformer used python-opencv to process image. """ def __init__( @@ -97,9 +100,9 @@ class CvTransfomer(ImageTransformer): return self.transform(im) -class PILTransfomer(ImageTransformer): +class PILTransformer(ImageTransformer): """ - PILTransfomer used PIL to process image. + PILTransformer used PIL to process image. """ def __init__( @@ -170,8 +173,11 @@ class PILTransfomer(ImageTransformer): return self.transform(im) -def warpper(cls, (dat, label)): - return cls.job(dat, label) +def job(is_img_string, transformer, (data, label)): + if is_img_string: + return transformer.transform_from_string(data), label + else: + return transformer.transform_from_file(data), label class MultiProcessImageTransformer(object): @@ -238,36 +244,19 @@ class MultiProcessImageTransformer(object): :type is_img_string: bool. """ + self.procnum = procnum self.pool = multiprocessing.Pool(procnum) self.is_img_string = is_img_string if cv2 is not None: - self.transformer = CvTransfomer(resize_size, crop_size, transpose, - channel_swap, mean, is_train, - is_color) - else: - self.transformer = PILTransfomer(resize_size, crop_size, transpose, + self.transformer = CvTransformer(resize_size, crop_size, transpose, channel_swap, mean, is_train, is_color) - - def run(self, data, label): - try: - fun = partial(warpper, self) - return self.pool.imap_unordered(fun, zip(data, label), chunksize=5) - except KeyboardInterrupt: - self.pool.terminate() - except Exception, e: - self.pool.terminate() - - def job(self, data, label): - if self.is_img_string: - return self.transformer.transform_from_string(data), label else: - return self.transformer.transform_from_file(data), label - - def __getstate__(self): - self_dict = self.__dict__.copy() - del self_dict['pool'] - return self_dict + self.transformer = PILTransformer(resize_size, crop_size, transpose, + channel_swap, mean, is_train, + is_color) - def __setstate__(self, state): - self.__dict__.update(state) + def run(self, data, label): + fun = functools.partial(job, self.is_img_string, self.transformer) + return self.pool.imap_unordered( + fun, itertools.izip(data, label), chunksize=100 * self.procnum)