提交 3dc6dedf 编写于 作者: M muli

tiny fix for fcn

上级 d97b0091
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
```{.python .input n=1} ```{.python .input n=1}
import os import os
import tarfile import tarfile
from mxnet import gluon from mxnet import gluon
data_root = '../data' data_root = '../data'
...@@ -65,7 +65,7 @@ train_images, train_labels = read_images() ...@@ -65,7 +65,7 @@ train_images, train_labels = read_images()
imgs = [] imgs = []
for i in range(3): for i in range(3):
imgs += [train_images[i], train_labels[i]] imgs += [train_images[i], train_labels[i]]
utils.show_images(imgs, nrows=3, ncols=2, figsize=(12,8)) utils.show_images(imgs, nrows=3, ncols=2, figsize=(12,8))
[im.shape for im in imgs] [im.shape for im in imgs]
``` ```
...@@ -82,9 +82,9 @@ def rand_crop(data, label, height, width): ...@@ -82,9 +82,9 @@ def rand_crop(data, label, height, width):
imgs = [] imgs = []
for _ in range(3): for _ in range(3):
imgs += rand_crop(train_images[0], train_labels[0], imgs += rand_crop(train_images[0], train_labels[0],
200, 300) 200, 300)
utils.show_images(imgs, nrows=3, ncols=2, figsize=(12,8)) utils.show_images(imgs, nrows=3, ncols=2, figsize=(12,8))
``` ```
...@@ -96,7 +96,7 @@ classes = ['background','aeroplane','bicycle','bird','boat', ...@@ -96,7 +96,7 @@ classes = ['background','aeroplane','bicycle','bird','boat',
'dog','horse','motorbike','person','potted plant', 'dog','horse','motorbike','person','potted plant',
'sheep','sofa','train','tv/monitor'] 'sheep','sofa','train','tv/monitor']
# RGB color for each class # RGB color for each class
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128], colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
[128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0], [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
[64,128,0],[192,128,0],[64,0,128],[192,0,128], [64,128,0],[192,128,0],[64,0,128],[192,0,128],
[64,128,128],[192,128,128],[0,64,0],[128,64,0], [64,128,128],[192,128,128],[0,64,0],[128,64,0],
...@@ -112,9 +112,9 @@ import numpy as np ...@@ -112,9 +112,9 @@ import numpy as np
from mxnet import nd from mxnet import nd
cm2lbl = np.zeros(256**3) cm2lbl = np.zeros(256**3)
for i,cm in enumerate(colormap): for i,cm in enumerate(colormap):
cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i
def image2label(im): def image2label(im):
data = im.astype('int32').asnumpy() data = im.astype('int32').asnumpy()
idx = (data[:,:,0]*256+data[:,:,1])*256+data[:,:,2] idx = (data[:,:,0]*256+data[:,:,1])*256+data[:,:,2]
...@@ -142,12 +142,12 @@ def normalize_image(data): ...@@ -142,12 +142,12 @@ def normalize_image(data):
return (data.astype('float32') / 255 - rgb_mean) / rgb_std return (data.astype('float32') / 255 - rgb_mean) / rgb_std
class VOCSegDataset(gluon.data.Dataset): class VOCSegDataset(gluon.data.Dataset):
def _filter(self, images): def _filter(self, images):
return [im for im in images if ( return [im for im in images if (
im.shape[0] >= self.crop_size[0] and im.shape[0] >= self.crop_size[0] and
im.shape[1] >= self.crop_size[1])] im.shape[1] >= self.crop_size[1])]
def __init__(self, train, crop_size): def __init__(self, train, crop_size):
self.crop_size = crop_size self.crop_size = crop_size
data, label = read_images(train=train) data, label = read_images(train=train)
...@@ -155,15 +155,15 @@ class VOCSegDataset(gluon.data.Dataset): ...@@ -155,15 +155,15 @@ class VOCSegDataset(gluon.data.Dataset):
self.data = [normalize_image(im) for im in data] self.data = [normalize_image(im) for im in data]
self.label = self._filter(label) self.label = self._filter(label)
print('Read '+str(len(self.data))+' examples') print('Read '+str(len(self.data))+' examples')
def __getitem__(self, idx): def __getitem__(self, idx):
data, label = rand_crop( data, label = rand_crop(
self.data[idx], self.label[idx], self.data[idx], self.label[idx],
*self.crop_size) *self.crop_size)
data = data.transpose((2,0,1)) data = data.transpose((2,0,1))
label = image2label(label) label = image2label(label)
return data, label return data, label
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
``` ```
...@@ -244,7 +244,7 @@ pretrained_net = models.resnet18_v2(pretrained=True) ...@@ -244,7 +244,7 @@ pretrained_net = models.resnet18_v2(pretrained=True)
net = nn.HybridSequential() net = nn.HybridSequential()
for layer in pretrained_net.features[:-2]: for layer in pretrained_net.features[:-2]:
net.add(layer) net.add(layer)
x = nd.random.uniform(shape=(1,3,*input_shape)) x = nd.random.uniform(shape=(1,3,*input_shape))
print('Input:', x.shape) print('Input:', x.shape)
print('Output:', net(x).shape) print('Output:', net(x).shape)
...@@ -259,7 +259,7 @@ with net.name_scope(): ...@@ -259,7 +259,7 @@ with net.name_scope():
net.add( net.add(
nn.Conv2D(num_classes, kernel_size=1), nn.Conv2D(num_classes, kernel_size=1),
nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16,strides=32) nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16,strides=32)
) )
``` ```
## 训练 ## 训练
...@@ -316,8 +316,8 @@ plt.show() ...@@ -316,8 +316,8 @@ plt.show()
from mxnet import init from mxnet import init
conv_trans = net[-1] conv_trans = net[-1]
conv_trans.initialize(init=init.Xavier()) conv_trans.initialize(init=init.Zero())
net[-2].initialize(init=init.Zero()) net[-2].initialize(init=init.Xavier())
x = nd.zeros((batch_size, 3, *input_shape)) x = nd.zeros((batch_size, 3, *input_shape))
net(x) net(x)
...@@ -334,7 +334,7 @@ import sys ...@@ -334,7 +334,7 @@ import sys
sys.path.append('..') sys.path.append('..')
import utils import utils
loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=1) loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
ctx = utils.try_all_gpus() ctx = utils.try_all_gpus()
net.collect_params().reset_ctx(ctx) net.collect_params().reset_ctx(ctx)
...@@ -375,7 +375,7 @@ for i in range(n): ...@@ -375,7 +375,7 @@ for i in range(n):
x = test_images[i] x = test_images[i]
pred = label2image(predict(x)) pred = label2image(predict(x))
imgs += [x, pred, test_labels[i]] imgs += [x, pred, test_labels[i]]
utils.show_images(imgs, nrows=n, ncols=3, figsize=(6,10)) utils.show_images(imgs, nrows=n, ncols=3, figsize=(6,10))
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册