提交 023d77db 编写于 作者: A Aston Zhang

revise ssd, fcn code

上级 edc21426
......@@ -165,11 +165,10 @@ gb.train(train_iter, test_iter, net, loss, trainer, ctx, num_epochs=5)
预测一张新图像时,我们只需要将其归一化并转成卷积网络需要的4D格式。
```{.python .input n=13}
def predict(im):
data = test_iter._dataset.normalize_image(im)
data = data.transpose((2, 0, 1)).expand_dims(axis=0)
yhat = net(data.as_in_context(ctx[0]))
pred = nd.argmax(yhat, axis=1)
def predict(img):
x = test_iter._dataset.normalize_image(img)
x = x.transpose((2, 0, 1)).expand_dims(axis=0)
pred = nd.argmax(net(x.as_in_context(ctx[0])), axis=1)
return pred.reshape((pred.shape[1], pred.shape[2]))
```
......@@ -185,7 +184,7 @@ def label2image(pred):
现在我们读取前几张测试图像并对其进行预测。
```{.python .input n=15}
test_images, test_labels = gb.read_voc_images(train=False)
test_images, test_labels = gb.read_voc_images(is_train=False)
n = 5
imgs = []
......
......@@ -50,25 +50,26 @@ voc_dir = download_voc_pascal()
```{.python .input n=3}
# 本函数已保存在 gluonbook 包中方便以后使用。
def read_voc_images(root=voc_dir, train=True):
def read_voc_images(root=voc_dir, is_train=True):
txt_fname = '%s/ImageSets/Segmentation/%s' % (
root, 'train.txt' if train else 'val.txt')
root, 'train.txt' if is_train else 'val.txt')
with open(txt_fname, 'r') as f:
images = f.read().split()
data, label = [None] * len(images), [None] * len(images)
features, labels = [None] * len(images), [None] * len(images)
for i, fname in enumerate(images):
data[i] = image.imread('%s/JPEGImages/%s.jpg' % (root, fname))
label[i] = image.imread('%s/SegmentationClass/%s.png' % (root, fname))
return data, label
features[i] = image.imread('%s/JPEGImages/%s.jpg' % (root, fname))
labels[i] = image.imread(
'%s/SegmentationClass/%s.png' % (root, fname))
return features, labels
train_images, train_labels = read_voc_images()
train_features, train_labels = read_voc_images()
```
我们画出前面五张图像和它们对应的标注。在标注,白色代表边框黑色代表背景,其他不同的颜色对应不同目标类别。
```{.python .input n=4}
n = 5
imgs = train_images[0:n] + train_labels[0:n]
imgs = train_features[0:n] + train_labels[0:n]
gb.show_images(imgs, 2, n);
```
......@@ -93,13 +94,14 @@ VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
```{.python .input n=6}
colormap2label = nd.zeros(256 ** 3)
for i, cm in enumerate(VOC_COLORMAP):
colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
for i, colormap in enumerate(VOC_COLORMAP):
colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i
# 本函数已保存在 gluonbook 包中方便以后使用。
def voc_label_indices(img, colormap2label):
data = img.astype('int32')
idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
def voc_label_indices(colormap, colormap2label):
colormap = colormap.astype('int32')
idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
+ colormap[:, :, 2])
return colormap2label[idx]
```
......@@ -118,14 +120,14 @@ y[105:115, 130:140], VOC_CLASSES[1]
```{.python .input n=8}
# 本函数已保存在 gluonbook 包中方便以后使用。
def voc_rand_crop(data, label, height, width):
data, rect = image.random_crop(data, (width, height))
def voc_rand_crop(feature, label, height, width):
feature, rect = image.random_crop(feature, (width, height))
label = image.fixed_crop(label, *rect)
return data, label
return feature, label
imgs = []
for _ in range(n):
imgs += voc_rand_crop(train_images[0], train_labels[0], 200, 300)
imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300)
gb.show_images(imgs[::2] + imgs[1::2], 2, n);
```
......@@ -136,32 +138,33 @@ gb.show_images(imgs[::2] + imgs[1::2], 2, n);
```{.python .input n=9}
# 本类已保存在 gluonbook 包中方便以后使用。
class VOCSegDataset(gdata.Dataset):
def __init__(self, train, crop_size, voc_dir, colormap2label):
def __init__(self, is_train, crop_size, voc_dir, colormap2label):
self.rgb_mean = nd.array([0.485, 0.456, 0.406])
self.rgb_std = nd.array([0.229, 0.224, 0.225])
self.crop_size = crop_size
data, label = read_voc_images(root=voc_dir, train=train)
self.data = [self.normalize_image(im) for im in self.filter(data)]
self.label = self.filter(label)
features, labels = read_voc_images(root=voc_dir, is_train=is_train)
self.features = [self.normalize_image(feature)
for feature in self.filter(features)]
self.labels = self.filter(labels)
self.colormap2label = colormap2label
print('read ' + str(len(self.data)) + ' examples')
print('read ' + str(len(self.features)) + ' examples')
def normalize_image(self, data):
return (data.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
def normalize_image(self, img):
return (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
def filter(self, images):
return [im for im in images if (
im.shape[0] >= self.crop_size[0] and
im.shape[1] >= self.crop_size[1])]
def filter(self, imgs):
return [img for img in imgs if (
img.shape[0] >= self.crop_size[0] and
img.shape[1] >= self.crop_size[1])]
def __getitem__(self, idx):
data, label = voc_rand_crop(self.data[idx], self.label[idx],
*self.crop_size)
return (data.transpose((2, 0, 1)),
feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
*self.crop_size)
return (feature.transpose((2, 0, 1)),
voc_label_indices(label, self.colormap2label))
def __len__(self):
return len(self.data)
return len(self.features)
```
假设我们裁剪$320\times 480$图像来进行训练,我们可以查看训练和测试各保留了多少图像。
......
......@@ -277,8 +277,8 @@ for epoch in range(20):
```{.python .input n=20}
def process_image(file_name):
img = image.imread(file_name)
data = image.imresize(img, 256, 256).astype('float32')
return data.transpose((2, 0, 1)).expand_dims(axis=0), img
feature = image.imresize(img, 256, 256).astype('float32')
return feature.transpose((2, 0, 1)).expand_dims(axis=0), img
x, img = process_image('../img/pikachu.jpg')
```
......@@ -310,7 +310,7 @@ def display(img, out, threshold=0.5):
bbox = [row[2:6] * nd.array(img.shape[0:2] * 2, ctx=row.context)]
gb.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')
display(img, out, threshold=0.01)
display(img, out, threshold=0.3)
```
## 小结
......
......@@ -352,17 +352,18 @@ def read_imdb(folder='train'):
return data
def read_voc_images(root='../data/VOCdevkit/VOC2012', train=True):
def read_voc_images(root='../data/VOCdevkit/VOC2012', is_train=True):
"""Read VOC images."""
txt_fname = '%s/ImageSets/Segmentation/%s' % (
root, 'train.txt' if train else 'val.txt')
root, 'train.txt' if is_train else 'val.txt')
with open(txt_fname, 'r') as f:
images = f.read().split()
data, label = [None] * len(images), [None] * len(images)
features, labels = [None] * len(images), [None] * len(images)
for i, fname in enumerate(images):
data[i] = image.imread('%s/JPEGImages/%s.jpg' % (root, fname))
label[i] = image.imread('%s/SegmentationClass/%s.png' % (root, fname))
return data, label
features[i] = image.imread('%s/JPEGImages/%s.jpg' % (root, fname))
labels[i] = image.imread(
'%s/SegmentationClass/%s.png' % (root, fname))
return features, labels
class Residual(nn.Block):
......@@ -775,29 +776,30 @@ def use_svg_display():
display.set_matplotlib_formats('svg')
def voc_label_indices(img, colormap2label):
def voc_label_indices(colormap, colormap2label):
"""Assig label indices for Pascal VOC2012 Dataset."""
data = img.astype('int32')
idx = (data[:,:,0] * 256 + data[:,:,1]) * 256 + data[:,:,2]
colormap = colormap.astype('int32')
idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
+ colormap[:, :, 2])
return colormap2label[idx]
def voc_rand_crop(data, label, height, width):
def voc_rand_crop(feature, label, height, width):
"""Random cropping for images of the Pascal VOC2012 Dataset."""
data, rect = image.random_crop(data, (width, height))
feature, rect = image.random_crop(feature, (width, height))
label = image.fixed_crop(label, *rect)
return data, label
return feature, label
class VOCSegDataset(gdata.Dataset):
"""The Pascal VOC2012 Dataset."""
def __init__(self, train, crop_size, voc_dir, colormap2label):
def __init__(self, is_train, crop_size, voc_dir, colormap2label):
self.rgb_mean = nd.array([0.485, 0.456, 0.406])
self.rgb_std = nd.array([0.229, 0.224, 0.225])
self.crop_size = crop_size
data, label = read_voc_images(root=voc_dir, train=train)
data, labels = read_voc_images(root=voc_dir, is_train=is_train)
self.data = [self.normalize_image(im) for im in self.filter(data)]
self.label = self.filter(label)
self.labels = self.filter(labels)
self.colormap2label = colormap2label
print('read ' + str(len(self.data)) + ' examples')
......@@ -810,10 +812,10 @@ class VOCSegDataset(gdata.Dataset):
im.shape[1] >= self.crop_size[1])]
def __getitem__(self, idx):
data, label = voc_rand_crop(self.data[idx], self.label[idx],
*self.crop_size)
data, labels = voc_rand_crop(self.data[idx], self.labels[idx],
*self.crop_size)
return (data.transpose((2, 0, 1)),
voc_label_indices(label, self.colormap2label))
voc_label_indices(labels, self.colormap2label))
def __len__(self):
return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册