import random
from time import time

from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import mxnet as mx
from mxnet import autograd, gluon, image, nd
from mxnet.gluon import nn, data as gdata, loss as gloss, utils as gutils
import numpy as np

# Set the default plot format and figure size.
set_matplotlib_formats('retina')
plt.rcParams['figure.figsize'] = (3.5, 2.5)

class DataLoader(object):
    """similiar to gluon.data.DataLoader, but might be faster.

    The main difference this data loader tries to read more exmaples each
    time. But the limits are 1) all examples in dataset have the same shape, 2)
    data transfomer needs to process multiple examples at each time
    """
    def __init__(self, dataset, batch_size, shuffle, transform=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.transform = transform

    def __iter__(self):
        data = self.dataset[:]
        X = data[0]
        y = nd.array(data[1])
        n = X.shape[0]
        if self.shuffle:
            idx = np.arange(n)
            np.random.shuffle(idx)
            X = nd.array(X.asnumpy()[idx])
            y = nd.array(y.asnumpy()[idx])

        for i in range(n//self.batch_size):
            if self.transform is not None:
                yield self.transform(X[i*self.batch_size:(i+1)*self.batch_size],
                                     y[i*self.batch_size:(i+1)*self.batch_size])
            else:
                yield (X[i*self.batch_size:(i+1)*self.batch_size],
                       y[i*self.batch_size:(i+1)*self.batch_size])

    def __len__(self):
        return len(self.dataset)//self.batch_size
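
# A minimal usage sketch (not part of the original file): `mnist_train` is a
# hypothetical dataset whose examples all share one shape, e.g. the object
# returned by gluon.data.vision.FashionMNIST.
#
#   loader = DataLoader(mnist_train, batch_size=256, shuffle=True)
#   for X, y in loader:
#       break  # X: (256, 28, 28, 1) NDArray, y: (256,) NDArray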


def load_data_fashion_mnist(batch_size, resize=None,
                            root="~/.mxnet/datasets/fashion-mnist"):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    def transform_mnist(data, label):
        # Transform a batch of examples.
        if resize:
            n = data.shape[0]
            new_data = nd.zeros((n, resize, resize, data.shape[3]))
            for i in range(n):
                new_data[i] = image.imresize(data[i], resize, resize)
            data = new_data
        # Change data from (batch, height, width, channel)
        # to (batch, channel, height, width).
        return (nd.transpose(data.astype('float32'), (0, 3, 1, 2)) / 255,
                label.astype('float32'))

    mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
    mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
    # Transform later to avoid memory explosion.
    train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist)
    test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform=transform_mnist)
    return (train_data, test_data)
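
# Usage sketch (assumed, not from the original file): resizing to 96 x 96
# gives NCHW batches ready for a convolutional net.
#
#   train_data, test_data = load_data_fashion_mnist(batch_size=256, resize=96)
#   for X, y in train_data:
#       break  # X: (256, 1, 96, 96) after the NHWC -> NCHW transpose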


def try_gpu():
    """If GPU is available, return mx.gpu(0); else return mx.cpu()"""
    try:
        ctx = mx.gpu()
        _ = nd.array([0], ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx
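
# Usage sketch: the nd.array([0], ctx=ctx) allocation above is the probe; it
# raises when no GPU is present, so the fallback to mx.cpu() kicks in.
#
#   ctx = try_gpu()
#   w = nd.random.normal(shape=(2, 1), ctx=ctx)  # parameters live on ctx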


def try_all_gpus():
    """Return all available GPUs, or [mx.gpu()] if there is no GPU"""
    ctxes = []
    try:
        for i in range(16):
            ctx = mx.gpu(i)
            _ = nd.array([0], ctx=ctx)
            ctxes.append(ctx)
    except mx.base.MXNetError:
        pass
    if not ctxes:
        ctxes = [mx.cpu()]
    return ctxes


def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent."""
    for param in params:
        param[:] = param - lr * param.grad / batch_size
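
# The update rule is param := param - lr * grad / batch_size; dividing by
# batch_size turns the gradient that autograd sums over a batch into a mean.
# Minimal sketch, assuming X, y, w, b exist and w, b have grads attached:
#
#   with autograd.record():
#       l = squared_loss(linreg(X, w, b), y)  # both defined later in this file
#   l.backward()
#   sgd([w, b], lr=0.03, batch_size=X.shape[0])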


def accuracy(y_hat, y):
    """Get accuracy."""
    return (y_hat.argmax(axis=1) == y).mean().asscalar()
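
# Example (a sketch): argmax over axis 1 picks the predicted class, so both
# rows below count as correct.
#
#   y_hat = nd.array([[0.1, 0.9], [0.8, 0.2]])
#   y = nd.array([1, 0])
#   accuracy(y_hat, y)  # -> 1.0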


def _get_batch(batch, ctx):
    """return features and labels on ctx"""
    if isinstance(batch, mx.io.DataBatch):
        features = batch.data[0]
        labels = batch.label[0]
    else:
        features, labels = batch
    if labels.dtype != features.dtype:
        labels = labels.astype(features.dtype)
    return (gutils.split_and_load(features, ctx),
            gutils.split_and_load(labels, ctx),
            features.shape[0])
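
# Sketch: with two GPUs and a 256-example batch, split_and_load shards the
# features and labels evenly across the devices (context list assumed):
#
#   Xs, ys, n = _get_batch((features, labels), [mx.gpu(0), mx.gpu(1)])
#   # len(Xs) == len(ys) == 2, each shard holds 128 examples, n == 256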


def evaluate_accuracy(data_iter, net, ctx=[mx.cpu()]):
    """Evaluate accuracy of a model on the given data set."""
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    acc = nd.array([0])
    n = 0
    if isinstance(data_iter, mx.io.MXDataIter):
        data_iter.reset()
    for batch in data_iter:
        features, labels, batch_size = _get_batch(batch, ctx)
        for X, y in zip(features, labels):
            y = y.astype('float32')
            acc += (net(X).argmax(axis=1) == y).sum().copyto(mx.cpu())
            n += y.size
        acc.wait_to_read()
    return acc.asscalar() / n


def train_cpu(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, trainer=None):
    """Train and evaluate a model on CPU."""
    for epoch in range(1, num_epochs + 1):
        train_l_sum = 0
        train_acc_sum = 0
        for X, y in train_iter:
            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat, y)
            l.backward()
            if trainer is None:
                sgd(params, lr, batch_size)
            else:
                trainer.step(batch_size)
            train_l_sum += l.mean().asscalar()
            train_acc_sum += accuracy(y_hat, y)
        test_acc = evaluate_accuracy(test_iter, net)
        print("epoch %d, loss %.4f, train acc %.3f, test acc %.3f"
              % (epoch, train_l_sum / len(train_iter),
                 train_acc_sum / len(train_iter), test_acc))
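
# Usage sketch for train_cpu with a softmax-regression net (all names below
# follow this file's imports; the hyperparameters are illustrative):
#
#   train_iter, test_iter = load_data_fashion_mnist(batch_size=256)
#   net = nn.Sequential()
#   net.add(nn.Dense(10))
#   net.initialize()
#   trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
#   train_cpu(net, train_iter, test_iter, gloss.SoftmaxCrossEntropyLoss(),
#             num_epochs=5, batch_size=256, trainer=trainer)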


def train(train_iter, test_iter, net, loss, trainer, ctx, num_epochs,
          print_batches=None):
    """Train and evaluate a model."""
    print("training on", ctx)
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    for epoch in range(1, num_epochs + 1):
        train_l_sum, train_acc_sum, n, m = 0.0, 0.0, 0.0, 0.0
        if isinstance(train_iter, mx.io.MXDataIter):
            train_iter.reset()
        start = time()
        for i, batch in enumerate(train_iter):
            Xs, ys, batch_size = _get_batch(batch, ctx)
            ls = []
            with autograd.record():
                y_hats = [net(X) for X in Xs]
                ls = [loss(y_hat, y) for y_hat, y in zip(y_hats, ys)]
            for l in ls:
                l.backward()
            train_acc_sum += sum([(y_hat.argmax(axis=1) == y).sum().asscalar()
                                 for y_hat, y in zip(y_hats, ys)])
            train_l_sum += sum([l.sum().asscalar() for l in ls])
            trainer.step(batch_size)
            n += batch_size
            m += sum([y.size for y in ys])
            if print_batches and (i+1) % print_batches == 0:
                print("batch %d, loss %f, train acc %f" % (
                    n, train_l_sum / n, train_acc_sum / m
                ))
        test_acc = evaluate_accuracy(test_iter, net, ctx)
        print("epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec" % (
            epoch, train_l_sum / n, train_acc_sum / m, test_acc, time() - start
        ))
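
# Usage sketch: train on every available GPU, falling back to the CPU. The
# resnet18 helper is defined later in this file; hyperparameters are
# illustrative.
#
#   ctx = try_all_gpus()
#   net = resnet18(10)
#   net.initialize(ctx=ctx)
#   trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
#   train(train_iter, test_iter, net, gloss.SoftmaxCrossEntropyLoss(),
#         trainer, ctx, num_epochs=5)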


class Residual(nn.HybridBlock):
    def __init__(self, channels, same_shape=True, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.same_shape = same_shape
        with self.name_scope():
            strides = 1 if same_shape else 2
            self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,
                                   strides=strides)
            self.bn1 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm()
            if not same_shape:
                self.conv3 = nn.Conv2D(channels, kernel_size=1,
                                       strides=strides)

    def hybrid_forward(self, F, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if not self.same_shape:
            x = self.conv3(x)
        return F.relu(out + x)


def resnet18(num_classes):
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(
            nn.BatchNorm(),
            nn.Conv2D(64, kernel_size=3, strides=1),
            nn.MaxPool2D(pool_size=3, strides=2),
            Residual(64),
            Residual(64),
            Residual(128, same_shape=False),
            Residual(128),
            Residual(256, same_shape=False),
            Residual(256),
            nn.GlobalAvgPool2D(),
            nn.Dense(num_classes)
        )
    return net
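
# Shape check (a sketch): the network expects NCHW input, e.g. 96 x 96
# single-channel images.
#
#   net = resnet18(10)
#   net.initialize()
#   net.hybridize()
#   X = nd.random.uniform(shape=(1, 1, 96, 96))
#   net(X).shape  # -> (1, 10)
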

def show_images(imgs, num_rows, num_cols, scale=2):
    """plot a list of images"""
    figsize = (num_cols*scale, num_rows*scale)
    _, figs = plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            figs[i][j].imshow(imgs[i*num_cols+j].asnumpy())
            figs[i][j].axes.get_xaxis().set_visible(False)
            figs[i][j].axes.get_yaxis().set_visible(False)
    plt.show()


def to_onehot(X, size):
    """Represent inputs with one-hot encoding."""
    return [nd.one_hot(x, size) for x in X.T]
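
# Shape sketch: X of shape (batch_size, num_steps) becomes a list of
# num_steps matrices, each of shape (batch_size, size):
#
#   X = nd.arange(6).reshape((2, 3))
#   inputs = to_onehot(X, 10)
#   # len(inputs) == 3, inputs[0].shape == (2, 10)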


def data_iter_random(corpus_indices, batch_size, num_steps, ctx=None):
    """Sample mini-batches in a random order from sequential data."""
    num_examples = (len(corpus_indices) - 1) // num_steps
    epoch_size = num_examples // batch_size
    example_indices = list(range(num_examples))
    random.shuffle(example_indices)
    def _data(pos):
        return corpus_indices[pos : pos+num_steps]
    for i in range(epoch_size):
        i = i * batch_size
        batch_indices = example_indices[i : i+batch_size]
        X = nd.array(
            [_data(j * num_steps) for j in batch_indices], ctx=ctx)
        Y = nd.array(
            [_data(j * num_steps + 1) for j in batch_indices], ctx=ctx)
        yield X, Y


def data_iter_consecutive(corpus_indices, batch_size, num_steps, ctx=None):
    """Sample mini-batches in a consecutive order from sequential data."""
    corpus_indices = nd.array(corpus_indices, ctx=ctx)
    data_len = len(corpus_indices)
    batch_len = data_len // batch_size
    indices = corpus_indices[0 : batch_size*batch_len].reshape((
        batch_size, batch_len))
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        i = i * num_steps
        X = indices[:, i : i+num_steps]
        Y = indices[:, i+1 : i+num_steps+1]
        yield X, Y
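
# Contrast of the two samplers (a sketch): both yield X of shape
# (batch_size, num_steps) with Y shifted one step ahead, but the consecutive
# sampler keeps each row adjacent across iterations, so the hidden state can
# be carried over between mini-batches.
#
#   for X, Y in data_iter_consecutive(list(range(30)), 2, 6):
#       print(X.shape, Y.shape)  # (2, 6) (2, 6), twice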


def grad_clipping(params, theta, ctx):
    """Clip the gradient."""
    if theta is not None:
        norm = nd.array([0.0], ctx)
        for param in params:
            norm += (param.grad ** 2).sum()
        norm = norm.sqrt().asscalar()
        if norm > theta:
            for param in params:
                param.grad[:] *= theta / norm
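
# When the global L2 norm ||g|| over all parameter gradients exceeds theta,
# every gradient is scaled by theta / ||g||, so the clipped norm is exactly
# theta. Sketch with a hypothetical parameter list:
#
#   grad_clipping(params, theta=0.01, ctx=mx.cpu())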


def predict_rnn(rnn, prefix, num_chars, params, num_hiddens, vocab_size, ctx,
                idx_to_char, char_to_idx, get_inputs, is_lstm=False):
    """Predict the next chars given the prefix."""
    prefix = prefix.lower()
    state_h = nd.zeros(shape=(1, num_hiddens), ctx=ctx)
    if is_lstm:
        state_c = nd.zeros(shape=(1, num_hiddens), ctx=ctx)
    output = [char_to_idx[prefix[0]]]
    for i in range(num_chars + len(prefix)):
        X = nd.array([output[-1]], ctx=ctx)
        if is_lstm:
            Y, state_h, state_c = rnn(get_inputs(X, vocab_size), state_h,
                                      state_c, *params)
        else:
            Y, state_h = rnn(get_inputs(X, vocab_size), state_h, *params)
        if i < len(prefix) - 1:
            next_input = char_to_idx[prefix[i + 1]]
        else:
            next_input = int(Y[0].argmax(axis=1).asscalar())
        output.append(next_input)
    return ''.join([idx_to_char[i] for i in output])


def train_and_predict_rnn(rnn, is_random_iter, num_epochs, num_steps,
                          num_hiddens, lr, clipping_theta, batch_size,
                          vocab_size, pred_period, pred_len, prefixes,
                          get_params, get_inputs, ctx, corpus_indices,
                          idx_to_char, char_to_idx, is_lstm=False):
    """Train an RNN model and predict the next item in the sequence."""
    if is_random_iter:
        data_iter = data_iter_random
    else:
        data_iter = data_iter_consecutive
    params = get_params()
    loss = gloss.SoftmaxCrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        if not is_random_iter:
            state_h = nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx)
            if is_lstm:
                state_c = nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx)
        train_l_sum = nd.array([0], ctx=ctx)
        train_l_cnt = 0
        for X, Y in data_iter(corpus_indices, batch_size, num_steps, ctx):
            if is_random_iter:
                state_h = nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx)
                if is_lstm:
                    state_c = nd.zeros(shape=(batch_size, num_hiddens),
                                       ctx=ctx)
            else:
                state_h = state_h.detach()
                if is_lstm:
                    state_c = state_c.detach()
            with autograd.record():
                if is_lstm:
                    outputs, state_h, state_c = rnn(
                        get_inputs(X, vocab_size), state_h, state_c, *params)
                else:
                    outputs, state_h = rnn(
                        get_inputs(X, vocab_size), state_h, *params)
                y = Y.T.reshape((-1,))
                outputs = nd.concat(*outputs, dim=0)
                l = loss(outputs, y)
            l.backward()
            grad_clipping(params, clipping_theta, ctx)
            sgd(params, lr, 1)
            train_l_sum = train_l_sum + l.sum()
            train_l_cnt += l.size
        if epoch % pred_period == 0:
            print("\nepoch %d, perplexity %f"
                  % (epoch, (train_l_sum / train_l_cnt).exp().asscalar()))
            for prefix in prefixes:
                print(' - ', predict_rnn(
                    rnn, prefix, pred_len, params, num_hiddens, vocab_size,
                    ctx, idx_to_char, char_to_idx, get_inputs, is_lstm))


def data_iter(batch_size, num_examples, features, labels):
    """Iterate through a data set."""
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        yield features.take(j), labels.take(j)


def linreg(X, w, b):
    """Linear regression."""
    return nd.dot(X, w) + b


def squared_loss(y_hat, y):
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
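
# End-to-end sketch tying data_iter, linreg, squared_loss and sgd together on
# synthetic data (one pass; all helper names come from this file):
#
#   features = nd.random.normal(shape=(1000, 2))
#   labels = 4.2 + nd.dot(features, nd.array([2, -3.4]))
#   w = nd.random.normal(scale=0.01, shape=(2, 1)); w.attach_grad()
#   b = nd.zeros((1,)); b.attach_grad()
#   for X, y in data_iter(10, 1000, features, labels):
#       with autograd.record():
#           l = squared_loss(linreg(X, w, b), y)
#       l.backward()
#       sgd([w, b], 0.03, 10)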


def optimize(batch_size, trainer, num_epochs, decay_epoch, log_interval,
             features, labels, net):
    """Optimize an objective function."""
    dataset = gdata.ArrayDataset(features, labels)
    data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)
    loss = gloss.L2Loss()
    ls = [loss(net(features), labels).mean().asnumpy()]
    for epoch in range(1, num_epochs + 1):
        # Decay the learning rate.
        if decay_epoch and epoch > decay_epoch:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
        for batch_i, (X, y) in enumerate(data_iter):
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            trainer.step(batch_size)
            if batch_i * batch_size % log_interval == 0:
                ls.append(loss(net(features), labels).mean().asnumpy())
    # Use NumPy values here so the recorded losses are easy to print and plot.
    print('w:', net[0].weight.data(), '\nb:', net[0].bias.data(), '\n')
    es = np.linspace(0, num_epochs, len(ls), endpoint=True)
    semilogy(es, ls, 'epoch', 'loss')


def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
             legend=None, figsize=(3.5, 2.5)):
    """Plot x and log(y)."""
    plt.rcParams['figure.figsize'] = figsize
    set_matplotlib_formats('retina')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        plt.semilogy(x2_vals, y2_vals)
        plt.legend(legend)
    plt.show()