# reader.py
import numpy as np
import tensorflow as tf

from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.datasets import imdb


FLAGS = tf.app.flags.FLAGS
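
# NOTE: FLAGS.max_len is read below but never defined in this module; it is
# assumed to be declared by the importing script. A minimal sketch of such a
# declaration (the default of 100 is an arbitrary assumption):
#
#     tf.app.flags.DEFINE_integer('max_len', 100,
#                                 'Maximum sequence length after padding.')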

class DataSet(object):
    """Holds a data/label array pair and serves shuffled minibatches."""

    def __init__(self, data, labels):
        assert data.shape[0] == labels.shape[0], (
            'data.shape: %s labels.shape: %s' % (data.shape, labels.shape))
        self._num_examples = data.shape[0]

        self._data = data
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def data(self):
        return self._data

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples, reshuffling at epoch ends."""
        assert batch_size <= self._num_examples

        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._data = self._data[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size

        end = self._index_in_epoch

        return self._data[start:end], self._labels[start:end]


def create_datasets(file_path, vocab_size=30000, val_fraction=0.0):
    # IMDB dataset loading
    train, test, _ = imdb.load_data(path=file_path, n_words=vocab_size,
                                    valid_portion=val_fraction,
                                    sort_by_len=False)
    trainX, trainY = train
    testX, testY = test

    # Data preprocessing
    # Sequence padding: pad (or truncate) every review to FLAGS.max_len ids
    trainX = pad_sequences(trainX, maxlen=FLAGS.max_len, value=0.)
    testX = pad_sequences(testX, maxlen=FLAGS.max_len, value=0.)
    # Convert labels to binary one-hot vectors, e.g. 1 -> [0., 1.]
    trainY = to_categorical(trainY, nb_classes=2)
    testY = to_categorical(testY, nb_classes=2)

    train_dataset = DataSet(trainX, trainY)
    test_dataset = DataSet(testX, testY)

    return train_dataset, test_dataset
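
# Illustrative usage (a sketch, not part of the original file): drive a
# training loop from DataSet.next_batch. The batch size and epoch budget
# below are arbitrary assumptions.
#
#     train_set, test_set = create_datasets('imdb.pkl')
#     while train_set.epochs_completed < 10:
#         batch_data, batch_labels = train_set.next_batch(32)
#         # feed batch_data / batch_labels to the model here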


def main():
    train_dataset, test_dataset = create_datasets('imdb.pkl')
    print('train examples: %d, test examples: %d'
          % (train_dataset.num_examples, test_dataset.num_examples))


if __name__ == "__main__":
    main()