from math import exp
import random
from time import time

from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import mxnet as mx
from mxnet import autograd, gluon, image, nd
from mxnet.gluon import nn, data as gdata, loss as gloss, utils as gutils
import numpy as np

# Set the default figure size and display figures in retina resolution.
set_matplotlib_formats('retina')
plt.rcParams['figure.figsize'] = (3.5, 2.5)

class DataLoader(object):
    """similiar to gluon.data.DataLoader, but might be faster.

    The main difference this data loader tries to read more exmaples each
    time. But the limits are 1) all examples in dataset have the same shape, 2)
    data transfomer needs to process multiple examples at each time
    """
    def __init__(self, dataset, batch_size, shuffle, transform=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.transform = transform

    def __iter__(self):
        data = self.dataset[:]
        X = data[0]
        y = nd.array(data[1])
        n = X.shape[0]
        if self.shuffle:
            idx = np.arange(n)
            np.random.shuffle(idx)
            X = nd.array(X.asnumpy()[idx])
            y = nd.array(y.asnumpy()[idx])

        for i in range(n//self.batch_size):
            if self.transform is not None:
                yield self.transform(X[i*self.batch_size:(i+1)*self.batch_size],
                                     y[i*self.batch_size:(i+1)*self.batch_size])
            else:
                yield (X[i*self.batch_size:(i+1)*self.batch_size],
                       y[i*self.batch_size:(i+1)*self.batch_size])

    def __len__(self):
        return len(self.dataset)//self.batch_size
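
# A minimal usage sketch (hypothetical names; `transform_mnist` is the
# batch-level transform defined inside load_data_fashion_mnist below):
#   loader = DataLoader(mnist_train, batch_size=256, shuffle=True,
#                       transform=transform_mnist)
#   for X, y in loader:
#       pass  # X and y each hold one mini-batch of 256 examples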


def load_data_fashion_mnist(batch_size, resize=None,
                            root="~/.mxnet/datasets/fashion-mnist"):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    def transform_mnist(data, label):
        # Transform a batch of examples.
        if resize:
            n = data.shape[0]
            new_data = nd.zeros((n, resize, resize, data.shape[3]))
            for i in range(n):
                new_data[i] = image.imresize(data[i], resize, resize)
            data = new_data
        # Change data from (batch, height, width, channel) to
        # (batch, channel, height, width).
        return (nd.transpose(data.astype('float32'), (0, 3, 1, 2)) / 255,
                label.astype('float32'))

    mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True,
                                                 transform=None)
    mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False,
                                                transform=None)
    # Transform later to avoid memory explosion.
    train_data = DataLoader(mnist_train, batch_size, shuffle=True,
                            transform=transform_mnist)
    test_data = DataLoader(mnist_test, batch_size, shuffle=False,
                           transform=transform_mnist)
    return (train_data, test_data)
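
# Usage sketch (batch size and resize chosen for illustration):
#   train_data, test_data = load_data_fashion_mnist(batch_size=256, resize=96)
#   for X, y in train_data:
#       pass  # X: (256, 1, 96, 96) float32 scaled to [0, 1]; y: (256,)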


def try_gpu():
    """If a GPU is available, return mx.gpu(0); otherwise return mx.cpu()."""
    try:
        ctx = mx.gpu()
        _ = nd.array([0], ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx
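
# Example: `ctx = try_gpu()` yields mx.gpu(0) when the first GPU is usable
# (the tiny allocation above is the probe) and mx.cpu() otherwise, so code
# like nd.array([1, 2, 3], ctx=ctx) runs unchanged on either device.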


def try_all_gpus():
    """Return all available GPUs, or [mx.gpu()] if there is no GPU"""
    ctxes = []
    try:
        for i in range(16):  # assume at most 16 GPUs
            ctx = mx.gpu(i)
            _ = nd.array([0], ctx=ctx)
            ctxes.append(ctx)
    except mx.base.MXNetError:
        pass
    if not ctxes:
        ctxes = [mx.cpu()]
    return ctxes
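
# Example: on a 2-GPU machine try_all_gpus() returns [mx.gpu(0), mx.gpu(1)];
# on a CPU-only machine it returns [mx.cpu()]. The list can be passed as the
# ctx argument of evaluate_accuracy and train below.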


def SGD(params, lr):
    """DEPRECATED!"""
    for param in params:
        param[:] = param - lr * param.grad


def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent."""
    for param in params:
        param[:] = param - lr * param.grad / batch_size
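
# Usage sketch: called after autograd has populated the gradients, e.g.
#   with autograd.record():
#       l = loss(net(X), y)
#   l.backward()
#   sgd([w, b], lr=0.03, batch_size=X.shape[0])
# Each parameter p is updated in place as p <- p - lr * p.grad / batch_size;
# dividing by the batch size keeps lr independent of the batch size when the
# loss is a sum over the mini-batch.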


def accuracy(y_hat, y):
    """Get accuracy."""
    return (y_hat.argmax(axis=1) == y).mean().asscalar()
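
# Example: accuracy(nd.array([[0.1, 0.9], [0.8, 0.2]]), nd.array([1, 1]))
# returns 0.5: the first row's argmax (1) matches its label, the second
# row's argmax (0) does not.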


def _get_batch(batch, ctx):
    """return features and labels on ctx"""
    if isinstance(batch, mx.io.DataBatch):
        features = batch.data[0]
        labels = batch.label[0]
    else:
        features, labels = batch
    if labels.dtype != features.dtype:
        labels = labels.astype(features.dtype)
    return (gutils.split_and_load(features, ctx),
            gutils.split_and_load(labels, ctx),
            features.shape[0])


def evaluate_accuracy(data_iter, net, ctx=[mx.cpu()]):
    """Evaluate accuracy of a model on the given data set."""
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    acc = nd.array([0])
    n = 0
    if isinstance(data_iter, mx.io.MXDataIter):
        data_iter.reset()
    for batch in data_iter:
        features, labels, batch_size = _get_batch(batch, ctx)
        for X, y in zip(features, labels):
            y = y.astype('float32')
            acc += (net(X).argmax(axis=1) == y).sum().copyto(mx.cpu())
            n += y.size
        acc.wait_to_read()
    return acc.asscalar() / n
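
# Usage sketch: the ctx argument may be a single context or a list, e.g.
#   acc = evaluate_accuracy(test_data, net, ctx=try_all_gpus())
# _get_batch splits each batch across the listed devices before scoring.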


def train_cpu(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, trainer=None):
    """Train and evaluate a model on CPU."""
    for epoch in range(1, num_epochs + 1):
        train_l_sum = 0
        train_acc_sum = 0
        for X, y in train_iter:
            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat, y)
            l.backward()
            if trainer is None:
                sgd(params, lr, batch_size)
            else:
                trainer.step(batch_size)
            train_l_sum += l.mean().asscalar()
            train_acc_sum += accuracy(y_hat, y)
        test_acc = evaluate_accuracy(test_iter, net)
        print("epoch %d, loss %.4f, train acc %.3f, test acc %.3f"
              % (epoch, train_l_sum / len(train_iter),
                 train_acc_sum / len(train_iter), test_acc))
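
# Usage sketch (scratch softmax regression; net, W, and b are hypothetical):
#   train_cpu(net, train_data, test_data, gloss.SoftmaxCrossEntropyLoss(),
#             num_epochs=5, batch_size=256, params=[W, b], lr=0.1)
# Passing a Gluon trainer instead of params/lr updates via trainer.step().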


def train(train_iter, test_iter, net, loss, trainer, ctx, num_epochs,
          print_batches=None):
    """Train and evaluate a model."""
    print("training on", ctx)
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    for epoch in range(1, num_epochs + 1):
        train_l_sum, train_acc_sum, n, m = 0.0, 0.0, 0.0, 0.0
        if isinstance(train_iter, mx.io.MXDataIter):
            train_iter.reset()
        start = time()
        for i, batch in enumerate(train_iter):
            Xs, ys, batch_size = _get_batch(batch, ctx)
            with autograd.record():
                y_hats = [net(X) for X in Xs]
                ls = [loss(y_hat, y) for y_hat, y in zip(y_hats, ys)]
            for l in ls:
                l.backward()
            train_acc_sum += sum([(y_hat.argmax(axis=1) == y).sum().asscalar()
                                 for y_hat, y in zip(y_hats, ys)])
            train_l_sum += sum([l.sum().asscalar() for l in ls])
            trainer.step(batch_size)
            n += batch_size
            m += sum([y.size for y in ys])
            if print_batches and (i + 1) % print_batches == 0:
                print("batch %d, loss %f, train acc %f" % (
                    i + 1, train_l_sum / n, train_acc_sum / m))
        test_acc = evaluate_accuracy(test_iter, net, ctx)
        print("epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec" % (
            epoch, train_l_sum / n, train_acc_sum / m, test_acc, time() - start
        ))
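
# Usage sketch (hypothetical hyperparameters; net and the data iterators are
# assumed to exist already):
#   ctx = try_all_gpus()
#   trainer = gluon.Trainer(net.collect_params(), 'sgd',
#                           {'learning_rate': 0.1})
#   train(train_data, test_data, net, gloss.SoftmaxCrossEntropyLoss(),
#         trainer, ctx, num_epochs=5)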


class Residual(nn.HybridBlock):
    def __init__(self, channels, same_shape=True, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.same_shape = same_shape
        with self.name_scope():
            strides = 1 if same_shape else 2
            self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,
                                   strides=strides)
            self.bn1 = nn.BatchNorm()
            self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm()
            if not same_shape:
                self.conv3 = nn.Conv2D(channels, kernel_size=1,
                                       strides=strides)

    def hybrid_forward(self, F, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if not self.same_shape:
            x = self.conv3(x)
        return F.relu(out + x)

def resnet18(num_classes):
    """A simplified ResNet-18 model."""
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(
            nn.BatchNorm(),
            nn.Conv2D(64, kernel_size=3, strides=1),
            nn.MaxPool2D(pool_size=3, strides=2),
            Residual(64),
            Residual(64),
            Residual(128, same_shape=False),
            Residual(128),
            Residual(256, same_shape=False),
            Residual(256),
            nn.GlobalAvgPool2D(),
            nn.Dense(num_classes)
        )
    return net
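
# Usage sketch (10-way classification such as Fashion-MNIST):
#   net = resnet18(num_classes=10)
#   net.initialize(init=mx.init.Xavier(), ctx=try_gpu())
#   net.hybridize()  # optional: compile the HybridBlocks to a static graph
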

def show_images(imgs, num_rows, num_cols, scale=2):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, figs = plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            figs[i][j].imshow(imgs[i * num_cols + j].asnumpy())
            figs[i][j].axes.get_xaxis().set_visible(False)
            figs[i][j].axes.get_yaxis().set_visible(False)
    plt.show()

def data_iter_random(corpus_indices, batch_size, num_steps, ctx=None):
    """Sample mini-batches in a random order from sequential data."""
    # Subtract 1 because label indices are corresponding input indices + 1.
    num_examples = (len(corpus_indices) - 1) // num_steps
    epoch_size = num_examples // batch_size
    # Randomize samples.
    example_indices = list(range(num_examples))
    random.shuffle(example_indices)

    def _data(pos):
        return corpus_indices[pos: pos + num_steps]

    for i in range(epoch_size):
        # Read batch_size random samples each time.
        i = i * batch_size
        batch_indices = example_indices[i: i + batch_size]
        data = nd.array(
            [_data(j * num_steps) for j in batch_indices], ctx=ctx)
        label = nd.array(
            [_data(j * num_steps + 1) for j in batch_indices], ctx=ctx)
        yield data, label
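
# Worked example: with corpus_indices=list(range(7)), batch_size=2, and
# num_steps=2, there are (7 - 1) // 2 = 3 examples, so one epoch yields a
# single batch, e.g. data [[0, 1], [4, 5]] with label [[1, 2], [5, 6]]
# (the example order is random).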

def data_iter_consecutive(corpus_indices, batch_size, num_steps, ctx=None):
    """Sample mini-batches in a consecutive order from sequential data."""
    corpus_indices = nd.array(corpus_indices, ctx=ctx)
    data_len = len(corpus_indices)
    batch_len = data_len // batch_size

    indices = corpus_indices[0: batch_size * batch_len].reshape((
        batch_size, batch_len))
    # Subtract 1 because label indices are corresponding input indices + 1.
    epoch_size = (batch_len - 1) // num_steps

    for i in range(epoch_size):
        i = i * num_steps
        data = indices[:, i: i + num_steps]
        label = indices[:, i + 1: i + num_steps + 1]
        yield data, label
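
# Worked example: with corpus_indices=list(range(10)), batch_size=2, and
# num_steps=2, indices becomes [[0..4], [5..9]] and the two batches are
#   data [[0, 1], [5, 6]], label [[1, 2], [6, 7]]
#   data [[2, 3], [7, 8]], label [[3, 4], [8, 9]]
# so each row continues where the previous batch left off.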


def grad_clipping(params, clipping_norm, ctx):
    """Gradient clipping."""
    if clipping_norm is not None:
        norm = nd.array([0.0], ctx)
        for p in params:
            norm += nd.sum(p.grad ** 2)
        norm = nd.sqrt(norm).asscalar()
        if norm > clipping_norm:
            for p in params:
                p.grad[:] *= clipping_norm / norm
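
# Example: with clipping_norm=5, if the global L2 norm of all gradients is
# 20, every gradient is scaled by 5 / 20 = 0.25 in place, so the clipped
# global norm is exactly 5.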


def predict_rnn(rnn, prefix, num_chars, params, hidden_dim, ctx, idx_to_char,
                char_to_idx, get_inputs, is_lstm=False):
    """Predict the next chars given the prefix."""
    prefix = prefix.lower()
    state_h = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
    if is_lstm:
        state_c = nd.zeros(shape=(1, hidden_dim), ctx=ctx)
    output = [char_to_idx[prefix[0]]]
    for i in range(num_chars + len(prefix) - 1):
        X = nd.array([output[-1]], ctx=ctx)
        if is_lstm:
            Y, state_h, state_c = rnn(get_inputs(X), state_h, state_c, *params)
        else:
            Y, state_h = rnn(get_inputs(X), state_h, *params)
        if i < len(prefix) - 1:
            next_input = char_to_idx[prefix[i + 1]]
        else:
            next_input = int(Y[0].argmax(axis=1).asscalar())
        output.append(next_input)
    return ''.join([idx_to_char[i] for i in output])
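
# Usage sketch (assumes the scratch RNN, its params, and the vocabulary
# mappings built in the accompanying notebooks):
#   predict_rnn(rnn, 'time traveller', 100, params, hidden_dim, ctx,
#               idx_to_char, char_to_idx, get_inputs)
# returns the lower-cased prefix followed by 100 generated characters.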


def train_and_predict_rnn(rnn, is_random_iter, epochs, num_steps, hidden_dim,
                          learning_rate, clipping_norm, batch_size,
                          pred_period, pred_len, seqs, get_params, get_inputs,
                          ctx, corpus_indices, idx_to_char, char_to_idx,
                          is_lstm=False):
    """Train an RNN model and predict the next item in the sequence."""
    if is_random_iter:
        data_iter = data_iter_random
    else:
        data_iter = data_iter_consecutive
    params = get_params()

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    for e in range(1, epochs + 1):
        # If consecutive sampling is used, in the same epoch, the hidden state
        # is initialized only at the beginning of the epoch.
        if not is_random_iter:
            state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
            if is_lstm:
                state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
        train_loss, num_examples = 0, 0
        for data, label in data_iter(corpus_indices, batch_size, num_steps,
                                     ctx):
            # If random sampling is used, the hidden state has to be
            # initialized for each mini-batch.
            if is_random_iter:
                state_h = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
                if is_lstm:
                    state_c = nd.zeros(shape=(batch_size, hidden_dim), ctx=ctx)
            with autograd.record():
                # outputs shape: (batch_size, vocab_size)
                if is_lstm:
                    outputs, state_h, state_c = rnn(get_inputs(data), state_h,
                                                    state_c, *params)
                else:
                    outputs, state_h = rnn(get_inputs(data), state_h, *params)
                # Let t_ib_j be the j-th element of the mini-batch at time i.
                # label shape: (batch_size * num_steps)
                # label = [t_0b_0, t_0b_1, ..., t_1b_0, t_1b_1, ..., ].
                label = label.T.reshape((-1,))
                # Concatenate outputs:
                # shape: (batch_size * num_steps, vocab_size).
                outputs = nd.concat(*outputs, dim=0)
                # Now outputs and label are aligned.
                loss = softmax_cross_entropy(outputs, label)
            loss.backward()

            grad_clipping(params, clipping_norm, ctx)
            SGD(params, learning_rate)

            train_loss += nd.sum(loss).asscalar()
            num_examples += loss.size

        if e % pred_period == 0:
            print("Epoch %d. Training perplexity %f" % (e,
M
muli 已提交
385 386 387 388 389 390 391 392
                                               exp(train_loss/num_examples)))
            for seq in seqs:
                print(' - ', predict_rnn(rnn, seq, pred_len, params,
                      hidden_dim, ctx, idx_to_char, char_to_idx, get_inputs,
                      is_lstm))
            print()


def data_iter(batch_size, num_examples, features, labels):
    """Iterate through a data set."""
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        yield features.take(j), labels.take(j)
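
# Usage sketch (hypothetical synthetic regression data):
#   features = nd.random.normal(shape=(1000, 2))
#   labels = nd.random.normal(shape=(1000,))
#   for X, y in data_iter(10, 1000, features, labels):
#       pass  # X: (10, 2), y: (10,), examples drawn in shuffled order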


def linreg(X, w, b):
    """Linear regression."""
    return nd.dot(X, w) + b


def squared_loss(y_hat, y):
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
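
# Worked example: linreg(nd.array([[1., 2.]]), nd.array([[2.], [3.]]),
# nd.array([4.])) gives [[12.]] since 1*2 + 2*3 + 4 = 12, and
# squared_loss(nd.array([12.]), nd.array([10.])) gives [2.] since
# (12 - 10)^2 / 2 = 2.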


def optimize(batch_size, trainer, num_epochs, decay_epoch, log_interval,
             features, labels, net):
    """Optimize an objective function."""
    dataset = gdata.ArrayDataset(features, labels)
    data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)
    loss = gloss.L2Loss()
    ls = [loss(net(features), labels).mean().asnumpy()]
    for epoch in range(1, num_epochs + 1):
        # Decay the learning rate.
        if decay_epoch and epoch > decay_epoch:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
        for batch_i, (X, y) in enumerate(data_iter):
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            trainer.step(batch_size)
            if batch_i * batch_size % log_interval == 0:
                ls.append(loss(net(features), labels).mean().asnumpy())
    # To print more conveniently, use numpy.
    print('w:', net[0].weight.data(), '\nb:', net[0].bias.data(), '\n')
    es = np.linspace(0, num_epochs, len(ls), endpoint=True)
    semilogy(es, ls, 'epoch', 'loss')


def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,
             legend=None, figsize=(3.5, 2.5)):
    """Plot x and log(y)."""
    plt.rcParams['figure.figsize'] = figsize
    set_matplotlib_formats('retina')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_vals, y_vals)
    if x2_vals and y2_vals:
        plt.semilogy(x2_vals, y2_vals)
        plt.legend(legend)
    plt.show()