提交 84d47ac2 编写于 作者: D dangqingqing

follow comments

上级 9d2f49c6
...@@ -3,7 +3,8 @@ import numpy as np ...@@ -3,7 +3,8 @@ import numpy as np
from PIL import Image from PIL import Image
from cStringIO import StringIO from cStringIO import StringIO
import multiprocessing import multiprocessing
from functools import partial import functools
import itertools
from paddle.utils.image_util import * from paddle.utils.image_util import *
from paddle.trainer.config_parser import logger from paddle.trainer.config_parser import logger
...@@ -14,10 +15,12 @@ except ImportError: ...@@ -14,10 +15,12 @@ except ImportError:
logger.warning("OpenCV2 is not installed, using PIL to prcoess") logger.warning("OpenCV2 is not installed, using PIL to prcoess")
cv2 = None cv2 = None
__all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"]
class CvTransfomer(ImageTransformer):
class CvTransformer(ImageTransformer):
""" """
CvTransfomer used python-opencv to process image. CvTransformer used python-opencv to process image.
""" """
def __init__( def __init__(
...@@ -97,9 +100,9 @@ class CvTransfomer(ImageTransformer): ...@@ -97,9 +100,9 @@ class CvTransfomer(ImageTransformer):
return self.transform(im) return self.transform(im)
class PILTransfomer(ImageTransformer): class PILTransformer(ImageTransformer):
""" """
PILTransfomer used PIL to process image. PILTransformer used PIL to process image.
""" """
def __init__( def __init__(
...@@ -170,8 +173,11 @@ class PILTransfomer(ImageTransformer): ...@@ -170,8 +173,11 @@ class PILTransfomer(ImageTransformer):
return self.transform(im) return self.transform(im)
def warpper(cls, (dat, label)): def job(is_img_string, transformer, (data, label)):
return cls.job(dat, label) if is_img_string:
return transformer.transform_from_string(data), label
else:
return transformer.transform_from_file(data), label
class MultiProcessImageTransformer(object): class MultiProcessImageTransformer(object):
...@@ -238,36 +244,19 @@ class MultiProcessImageTransformer(object): ...@@ -238,36 +244,19 @@ class MultiProcessImageTransformer(object):
:type is_img_string: bool. :type is_img_string: bool.
""" """
self.procnum = procnum
self.pool = multiprocessing.Pool(procnum) self.pool = multiprocessing.Pool(procnum)
self.is_img_string = is_img_string self.is_img_string = is_img_string
if cv2 is not None: if cv2 is not None:
self.transformer = CvTransfomer(resize_size, crop_size, transpose, self.transformer = CvTransformer(resize_size, crop_size, transpose,
channel_swap, mean, is_train, channel_swap, mean, is_train,
is_color) is_color)
else: else:
self.transformer = PILTransfomer(resize_size, crop_size, transpose, self.transformer = PILTransformer(resize_size, crop_size, transpose,
channel_swap, mean, is_train, channel_swap, mean, is_train,
is_color) is_color)
def run(self, data, label): def run(self, data, label):
try: fun = functools.partial(job, self.is_img_string, self.transformer)
fun = partial(warpper, self) return self.pool.imap_unordered(
return self.pool.imap_unordered(fun, zip(data, label), chunksize=5) fun, itertools.izip(data, label), chunksize=100 * self.procnum)
except KeyboardInterrupt:
self.pool.terminate()
except Exception, e:
self.pool.terminate()
def job(self, data, label):
if self.is_img_string:
return self.transformer.transform_from_string(data), label
else:
return self.transformer.transform_from_file(data), label
def __getstate__(self):
self_dict = self.__dict__.copy()
del self_dict['pool']
return self_dict
def __setstate__(self, state):
self.__dict__.update(state)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册