提交 05204af1 编写于 作者: Q qingqing01 提交者: Yu Yang

Fix memory leak in image classification demo, which is caused by dataprovider (#323)

* the memory leak is inside one pass.
上级 bd50f93e
...@@ -5,3 +5,5 @@ plot.png ...@@ -5,3 +5,5 @@ plot.png
train.log train.log
image_provider_copy_1.py image_provider_copy_1.py
*pyc *pyc
train.list
test.list
文件模式从 100644 更改为 100755
...@@ -58,24 +58,29 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg, ...@@ -58,24 +58,29 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg,
settings.logger.info('DataProvider Initialization finished') settings.logger.info('DataProvider Initialization finished')
@provider(init_hook=hook) @provider(init_hook=hook, min_pool_size=0)
def processData(settings, file_name): def processData(settings, file_list):
""" """
The main function for loading data. The main function for loading data.
Load the batch, iterate all the images and labels in this batch. Load the batch, iterate all the images and labels in this batch.
file_name: the batch file name. file_list: the batch file list.
""" """
data = cPickle.load(io.open(file_name, 'rb')) with open(file_list, 'r') as fdata:
indexes = list(range(len(data['images']))) lines = [line.strip() for line in fdata]
if settings.is_train: random.shuffle(lines)
random.shuffle(indexes) for file_name in lines:
for i in indexes: with io.open(file_name.strip(), 'rb') as file:
if settings.use_jpeg == 1: data = cPickle.load(file)
img = image_util.decode_jpeg(data['images'][i]) indexes = list(range(len(data['images'])))
else: if settings.is_train:
img = data['images'][i] random.shuffle(indexes)
img_feat = image_util.preprocess_img(img, settings.img_mean, for i in indexes:
settings.img_size, settings.is_train, if settings.use_jpeg == 1:
settings.color) img = image_util.decode_jpeg(data['images'][i])
label = data['labels'][i] else:
yield img_feat.tolist(), int(label) img = data['images'][i]
img_feat = image_util.preprocess_img(img, settings.img_mean,
settings.img_size, settings.is_train,
settings.color)
label = data['labels'][i]
yield img_feat.astype('float32'), int(label)
...@@ -35,6 +35,8 @@ if __name__ == '__main__': ...@@ -35,6 +35,8 @@ if __name__ == '__main__':
data_creator = ImageClassificationDatasetCreater(data_dir, data_creator = ImageClassificationDatasetCreater(data_dir,
processed_image_size, processed_image_size,
color) color)
data_creator.train_list_name = "train.txt"
data_creator.test_list_name = "test.txt"
data_creator.num_per_batch = 1000 data_creator.num_per_batch = 1000
data_creator.overwrite = True data_creator.overwrite = True
data_creator.create_batches() data_creator.create_batches()
...@@ -17,3 +17,6 @@ set -e ...@@ -17,3 +17,6 @@ set -e
data_dir=./data/cifar-out data_dir=./data/cifar-out
python preprocess.py -i $data_dir -s 32 -c 1 python preprocess.py -i $data_dir -s 32 -c 1
echo "data/cifar-out/batches/train.txt" > train.list
echo "data/cifar-out/batches/test.txt" > test.list
...@@ -25,8 +25,8 @@ if not is_predict: ...@@ -25,8 +25,8 @@ if not is_predict:
'img_size': 32,'num_classes': 10, 'img_size': 32,'num_classes': 10,
'use_jpeg': 1,'color': "color"} 'use_jpeg': 1,'color': "color"}
define_py_data_sources2(train_list=data_dir+"train.list", define_py_data_sources2(train_list="train.list",
test_list=data_dir+'test.list', test_list="train.list",
module='image_provider', module='image_provider',
obj='processData', obj='processData',
args=args) args=args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册