提交 bdffa40e 编写于 作者: W wwhu

add xmap for image list and modify the image reader of infer.py

上级 e9b94cab
......@@ -147,11 +147,11 @@ dataset_100/train_images/n02643566_75.jpeg 8
```python
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.test_reader('train.list'),
reader.train_reader('train.list'),
buf_size=1000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
reader.train_reader('val.list'),
reader.test_reader('val.list'),
batch_size=BATCH_SIZE)
```
......@@ -209,24 +209,10 @@ trainer.train(
with gzip.open('params_pass_10.tar.gz', 'r') as f:
parameters = paddle.parameters.Parameters.from_tar(f)
def load_image(file):
im = Image.open(file)
im = im.resize((224, 224), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
# The storage order of the loaded image is W(widht),
# H(height), C(channel). PaddlePaddle requires
# the CHW order, so transpose them.
im = im.transpose((2, 0, 1)) # CHW
# In the training phase, the channel order of CIFAR
# image is B(Blue), G(green), R(Red). But PIL open
# image in RGB mode. It must swap the channel order.
im = im[(2, 1, 0), :, :] # BGR
im = im.flatten()
im = im / 255.0
return im
file_list = [line.strip() for line in open(image_list_file)]
test_data = [(load_image(image_file),) for image_file in file_list]
test_data = [(paddle.image.load_and_transform(image_file, 256, 224, False)
.flatten().astype('float32'), )
for image_file in file_list]
probs = paddle.infer(
output_layer=out, parameters=parameters, input=test_data)
lab = np.argsort(-probs)
......@@ -234,4 +220,4 @@ for file_name, result in zip(file_list, lab):
print "Label of %s is: %d" % (file_name, result[0])
```
首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。`load_image`是一个加载图像的函数。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。
首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。
......@@ -3,7 +3,7 @@ import paddle.v2 as paddle
__all__ = ['alexnet']
def alexnet(input, class_dim=100):
def alexnet(input, class_dim):
conv1 = paddle.layer.img_conv(
input=input,
filter_size=11,
......
......@@ -3,8 +3,8 @@ import paddle.v2 as paddle
__all__ = ['googlenet']
def inception2(name, input, channels, filter1, filter3R, filter3, filter5R,
filter5, proj):
def inception(name, input, channels, filter1, filter3R, filter3, filter5R,
filter5, proj):
cov1 = paddle.layer.img_conv(
name=name + '_1',
input=input,
......@@ -65,7 +65,7 @@ def inception2(name, input, channels, filter1, filter3R, filter3, filter5R,
return cat
def googlenet(input, class_dim=100):
def googlenet(input, class_dim):
# stage 1
conv1 = paddle.layer.img_conv(
name="conv1",
......@@ -97,23 +97,23 @@ def googlenet(input, class_dim=100):
name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2)
# stage 3
ince3a = inception2("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32)
ince3b = inception2("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64)
ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32)
ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64)
pool3 = paddle.layer.img_pool(
name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2)
# stage 4
ince4a = inception2("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64)
ince4b = inception2("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64)
ince4c = inception2("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64)
ince4d = inception2("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64)
ince4e = inception2("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128)
ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64)
ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64)
ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64)
ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64)
ince4e = inception("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128)
pool4 = paddle.layer.img_pool(
name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2)
# stage 5
ince5a = inception2("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128)
ince5b = inception2("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128)
ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128)
ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128)
pool5 = paddle.layer.img_pool(
name="pool5",
input=ince5b,
......
......@@ -54,24 +54,9 @@ def main():
with gzip.open(args.params_path, 'r') as f:
parameters = paddle.parameters.Parameters.from_tar(f)
def load_image(file):
im = Image.open(file)
im = im.resize((WIDTH, HEIGHT), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
# The storage order of the loaded image is W(widht),
# H(height), C(channel). PaddlePaddle requires
# the CHW order, so transpose them.
im = im.transpose((2, 0, 1)) # CHW
# In the training phase, the channel order of CIFAR
# image is B(Blue), G(green), R(Red). But PIL open
# image in RGB mode. It must swap the channel order.
im = im[(2, 1, 0), :, :] # BGR
im = im.flatten()
im = im / 255.0
return im
file_list = [line.strip() for line in open(args.data_list)]
test_data = [(load_image(image_file), ) for image_file in file_list]
test_data = [(paddle.image.load_and_transform(image_file, 256, 224, False)
.flatten().astype('float32'), ) for image_file in file_list]
probs = paddle.infer(
output_layer=out, parameters=parameters, input=test_data)
lab = np.argsort(-probs)
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import random
from paddle.v2.image import load_and_transform
import paddle.v2 as paddle
from multiprocessing import cpu_count
def train_mapper(sample):
'''
map image path to type needed by model input layer for the training set
'''
img, label = sample
img = paddle.image.load_image(img)
img = paddle.image.simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label
def test_mapper(sample):
'''
map image path to type needed by model input layer for the test set
'''
img, label = sample
img = paddle.image.load_image(img)
img = paddle.image.simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label
def train_reader(train_list):
def train_reader(train_list, buffered_size=1024):
def reader():
with open(train_list, 'r') as f:
lines = [line.strip() for line in f]
random.shuffle(lines)
for line in lines:
img_path, lab = line.strip().split('\t')
im = load_and_transform(img_path, 256, 224, True)
yield im.flatten().astype('float32'), int(lab)
yield img_path, int(lab)
return reader
return paddle.reader.xmap_readers(train_mapper, reader,
cpu_count(), buffered_size)
def test_reader(test_list):
def test_reader(test_list, buffered_size=1024):
def reader():
with open(test_list, 'r') as f:
lines = [line.strip() for line in f]
for line in lines:
img_path, lab = line.strip().split('\t')
im = load_and_transform(img_path, 256, 224, False)
yield im.flatten().astype('float32'), int(lab)
yield img_path, int(lab)
return reader
return paddle.reader.xmap_readers(test_mapper, reader,
cpu_count(), buffered_size)
if __name__ == '__main__':
......
......@@ -55,7 +55,7 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride):
return conv
def resnet_imagenet(input, depth=50, class_dim=100):
def resnet_imagenet(input, class_dim, depth=50):
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
......@@ -78,7 +78,7 @@ def resnet_imagenet(input, depth=50, class_dim=100):
return out
def resnet_cifar10(input, depth=32, class_dim=10):
def resnet_cifar10(input, class_dim, depth=32):
# depth should be one of 20, 32, 44, 56, 110, 1202
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
......
......@@ -72,13 +72,13 @@ def main():
paddle.reader.shuffle(
flowers.train(),
# To use other data, replace the above line with:
# reader.test_reader('train.list'),
# reader.train_reader('train.list'),
buf_size=1000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
flowers.valid(),
# To use other data, replace the above line with:
# reader.train_reader('val.list'),
# reader.test_reader('val.list'),
batch_size=BATCH_SIZE)
# End batch and end pass event handler
......
......@@ -17,7 +17,7 @@ import paddle.v2 as paddle
__all__ = ['vgg13', 'vgg16', 'vgg19']
def vgg(input, nums, class_dim=100):
def vgg(input, nums, class_dim):
def conv_block(input, num_filter, groups, num_channels=None):
return paddle.networks.img_conv_group(
input=input,
......@@ -53,16 +53,16 @@ def vgg(input, nums, class_dim=100):
return out
def vgg13(input, class_dim=100):
def vgg13(input, class_dim):
nums = [2, 2, 2, 2, 2]
return vgg(input, nums, class_dim)
def vgg16(input, class_dim=100):
def vgg16(input, class_dim):
nums = [2, 2, 3, 3, 3]
return vgg(input, nums, class_dim)
def vgg19(input, class_dim=100):
def vgg19(input, class_dim):
nums = [2, 2, 4, 4, 4]
return vgg(input, nums, class_dim)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册