diff --git a/demo/image_classification/.gitignore b/demo/image_classification/.gitignore index 76961dd1436f859f85f75ff9ed7d3fefdec83dc4..6a05b8f6632db0977fceade8b48a89b9f7f6e6cc 100644 --- a/demo/image_classification/.gitignore +++ b/demo/image_classification/.gitignore @@ -5,3 +5,5 @@ plot.png train.log image_provider_copy_1.py *pyc +train.list +test.list diff --git a/demo/image_classification/data/download_cifar.sh b/demo/image_classification/data/download_cifar.sh old mode 100644 new mode 100755 diff --git a/demo/image_classification/image_provider.py b/demo/image_classification/image_provider.py index 9e2f8b8949b39b930680e6d84758133eed566881..305efbcdc6bb11f1dac65cc3af82fb997db97f27 100644 --- a/demo/image_classification/image_provider.py +++ b/demo/image_classification/image_provider.py @@ -58,24 +58,29 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg, settings.logger.info('DataProvider Initialization finished') -@provider(init_hook=hook) -def processData(settings, file_name): +@provider(init_hook=hook, min_pool_size=0) +def processData(settings, file_list): """ The main function for loading data. Load the batch, iterate all the images and labels in this batch. - file_name: the batch file name. + file_list: the batch file list. """ - data = cPickle.load(io.open(file_name, 'rb')) - indexes = list(range(len(data['images']))) - if settings.is_train: - random.shuffle(indexes) - for i in indexes: - if settings.use_jpeg == 1: - img = image_util.decode_jpeg(data['images'][i]) - else: - img = data['images'][i] - img_feat = image_util.preprocess_img(img, settings.img_mean, - settings.img_size, settings.is_train, - settings.color) - label = data['labels'][i] - yield img_feat.tolist(), int(label) + with open(file_list, 'r') as fdata: + lines = [line.strip() for line in fdata] + random.shuffle(lines) + for file_name in lines: + with io.open(file_name.strip(), 'rb') as file: + data = cPickle.load(file) + indexes = list(range(len(data['images']))) + if settings.is_train: + random.shuffle(indexes) + for i in indexes: + if settings.use_jpeg == 1: + img = image_util.decode_jpeg(data['images'][i]) + else: + img = data['images'][i] + img_feat = image_util.preprocess_img(img, settings.img_mean, + settings.img_size, settings.is_train, + settings.color) + label = data['labels'][i] + yield img_feat.astype('float32'), int(label) diff --git a/demo/image_classification/preprocess.py b/demo/image_classification/preprocess.py index 0286a5d7e9dc8d0f546b18b1ed846c9452cdbe4b..fe7ea19bf02776629dff0f64f5b671dc457eae64 100755 --- a/demo/image_classification/preprocess.py +++ b/demo/image_classification/preprocess.py @@ -35,6 +35,8 @@ if __name__ == '__main__': data_creator = ImageClassificationDatasetCreater(data_dir, processed_image_size, color) + data_creator.train_list_name = "train.txt" + data_creator.test_list_name = "test.txt" data_creator.num_per_batch = 1000 data_creator.overwrite = True data_creator.create_batches() diff --git a/demo/image_classification/preprocess.sh b/demo/image_classification/preprocess.sh index dfe3eb95d1ab8b2114fcf5e0f461ea0efb7cc1e5..e3e86ff10675c0622867af2eb0d26c87f4bc2db5 100755 --- a/demo/image_classification/preprocess.sh +++ b/demo/image_classification/preprocess.sh @@ -17,3 +17,6 @@ set -e data_dir=./data/cifar-out python preprocess.py -i $data_dir -s 32 -c 1 + +echo "data/cifar-out/batches/train.txt" > train.list +echo "data/cifar-out/batches/test.txt" > test.list diff --git a/demo/image_classification/vgg_16_cifar.py b/demo/image_classification/vgg_16_cifar.py index e8b8af4bd313d0738aafab8da93fc510e40cc3d6..edd6988c48acd6b554e09b721c37b291e21f46eb 100755 --- a/demo/image_classification/vgg_16_cifar.py +++ b/demo/image_classification/vgg_16_cifar.py @@ -25,8 +25,8 @@ if not is_predict: 'img_size': 32,'num_classes': 10, 'use_jpeg': 1,'color': "color"} - define_py_data_sources2(train_list=data_dir+"train.list", - test_list=data_dir+'test.list', + define_py_data_sources2(train_list="train.list", + test_list="train.list", module='image_provider', obj='processData', args=args)