imdb.py 5.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMDB dataset.

This module downloads the IMDB dataset from
http://ai.stanford.edu/%7Eamaas/data/sentiment/. The dataset contains a set
of 25,000 highly polar movie reviews for training and 25,000 for testing.
The module also provides an API for building a word dictionary.
"""

from __future__ import print_function

25
import paddle.dataset.common
26
import paddle.utils.deprecated as deprecated
27
import collections
Y
Yi Wang 已提交
28 29 30
import tarfile
import re
import string
M
minqiyang 已提交
31
import six
Y
Yi Wang 已提交
32

33 34
# Empty on purpose of this module's API: ``from ... import *`` exports
# nothing, so callers access the functions as module attributes.
__all__ = []

# Location of the aclImdb_v1.tar.gz archive. The original upstream location,
# http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz, was
# replaced by this URL.
URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
# MD5 digest handed to paddle.dataset.common.download together with URL
# (presumably used to verify the downloaded archive).
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'


def tokenize(pattern):
    """
    Read files that match the given pattern.  Tokenize and yield each file.

    The archive path is obtained from
    ``paddle.dataset.common.download(URL, 'imdb', MD5)``. Its members are
    then visited in archive order. For every member whose name matches
    ``pattern``, the raw bytes have trailing newline / carriage-return
    characters stripped, every ``string.punctuation`` character deleted, and
    are lower-cased and split on whitespace.

    :param pattern: compiled regular expression matched against member names
    :return: generator yielding one list of ``bytes`` tokens per matching file
    """
    with tarfile.open(paddle.dataset.common.download(URL, 'imdb', MD5)) as tarf:
        # Walk the members sequentially with tarfile.next() rather than doing
        # random-access lookups over the archive, which the original authors
        # warned can be very disk-intensive.
        tf = tarf.next()
        while tf is not None:
            if pattern.match(tf.name) is not None:
                # extractfile() returns None for non-regular members
                # (e.g. directories); there is nothing to tokenize then.
                fobj = tarf.extractfile(tf)
                if fobj is not None:
                    # Read and close the member before yielding, so no file
                    # object stays open while the generator is suspended.
                    try:
                        raw = fobj.read()
                    finally:
                        fobj.close()
                    # Newline and punctuation removal, then ad-hoc
                    # whitespace tokenization.
                    yield raw.rstrip(six.b("\n\r")).translate(
                        None, six.b(string.punctuation)).lower().split()
            tf = tarf.next()
Y
Yi Wang 已提交
58 59 60


def build_dict(pattern, cutoff):
    """
    Build a word dictionary from the corpus. Keys of the dictionary are words,
    and values are zero-based IDs of these words.

    A word is kept only if it occurs strictly more than ``cutoff`` times
    across all archive members matching ``pattern``. Kept words are ranked by
    descending frequency, with ties broken by the word itself. IDs are then
    assigned in that order starting from 0. The special key ``'<unk>'`` gets
    the next free ID, the number of kept words. If no word survives the
    cutoff, the result is ``{'<unk>': 0}``.

    :param pattern: compiled regex selecting the archive members to count
    :param cutoff: exclusive lower bound on a word's frequency
    :type cutoff: int
    :return: mapping from word to integer ID, including ``'<unk>'``
    :rtype: dict
    """
    word_freq = collections.Counter()
    for doc in tokenize(pattern):
        word_freq.update(doc)

    # Not sure if we should prune less-frequent words here.
    kept = [(word, freq) for word, freq in word_freq.items()
            if freq > cutoff]
    kept.sort(key=lambda item: (-item[1], item[0]))

    # enumerate copes with an empty `kept` list; unpacking zip(*kept) would
    # raise ValueError in that case.
    word_idx = {word: idx for idx, (word, _) in enumerate(kept)}
    word_idx['<unk>'] = len(kept)
    return word_idx


80 81 82
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imdb",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def reader_creator(pos_pattern, neg_pattern, word_idx):
    """
    Tokenize the positive and negative reviews eagerly and return a reader.

    All matching files are read and converted to ID sequences when this
    function is called. The returned reader only replays those cached
    samples. Words missing from ``word_idx`` map to ``word_idx['<unk>']``.

    :param pos_pattern: regex selecting positive-review members (label 0)
    :param neg_pattern: regex selecting negative-review members (label 1)
    :param word_idx: word dictionary that must contain ``'<unk>'``
    :type word_idx: dict
    :return: reader yielding ``(ids, label)`` pairs, all positives first
    :rtype: callable
    """
    unk_id = word_idx['<unk>']

    def to_ids(doc):
        # Out-of-vocabulary tokens fall back to the <unk> ID.
        return [word_idx.get(token, unk_id) for token in doc]

    samples = [(to_ids(doc), 0) for doc in tokenize(pos_pattern)]
    samples.extend((to_ids(doc), 1) for doc in tokenize(neg_pattern))

    def reader():
        for ids, label in samples:
            yield ids, label

    return reader
Y
Yi Wang 已提交
101 102


103 104 105
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imdb",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def train(word_idx):
    """
    IMDB training set creator.

    It returns a reader. Each sample from the reader is a sequence of
    zero-based word IDs paired with a label in [0, 1]: 0 for reviews under
    ``aclImdb/train/pos/`` and 1 for reviews under ``aclImdb/train/neg/``.

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: callable that yields ``(ids, label)`` training samples
    :rtype: callable
    """
    positive_files = re.compile(r"aclImdb/train/pos/.*\.txt$")
    negative_files = re.compile(r"aclImdb/train/neg/.*\.txt$")
    return reader_creator(positive_files, negative_files, word_idx)
Y
Yi Wang 已提交
123 124


125 126 127
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imdb",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def test(word_idx):
    """
    IMDB test set creator.

    It returns a reader. Each sample from the reader is a sequence of
    zero-based word IDs paired with a label in [0, 1]: 0 for reviews under
    ``aclImdb/test/pos/`` and 1 for reviews under ``aclImdb/test/neg/``.

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: callable that yields ``(ids, label)`` test samples
    :rtype: callable
    """
    positive_files = re.compile(r"aclImdb/test/pos/.*\.txt$")
    negative_files = re.compile(r"aclImdb/test/neg/.*\.txt$")
    return reader_creator(positive_files, negative_files, word_idx)
H
hedaoyuan 已提交
145 146


147 148 149
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imdb",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def word_dict():
    """
    Build a word dictionary from the corpus.

    Word frequencies are counted over both the train and test splits,
    covering positive and negative reviews. Only words occurring more than
    150 times are kept.

    :return: Word dictionary
    :rtype: dict
    """
    every_review = re.compile(
        r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
    # build_dict keeps words whose frequency is strictly greater than this.
    frequency_cutoff = 150
    return build_dict(every_review, frequency_cutoff)
Y
Yancey1989 已提交
161 162


163 164 165
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imdb",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def fetch():
    """
    Download the IMDB archive via ``paddle.dataset.common.download``.

    The downloader's return value is discarded; this function returns None.
    """
    dataset_common = paddle.dataset.common
    dataset_common.download(URL, 'imdb', MD5)