# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import re
import six
import string
import tarfile
import numpy as np
import collections

from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download

__all__ = ['Imdb']

URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'


class Imdb(Dataset):
    """
    Implementation of the IMDB dataset (the Large Movie Review Dataset
    for binary sentiment classification).

    Args:
        data_file(str): path to the data tar file; can be set to None if
            :attr:`download` is True. Default None.
        mode(str): 'train' or 'test' mode. Default 'train'.
        cutoff(int): cutoff frequency for building the word dictionary.
            Default 150.
        download(bool): whether to download the dataset automatically if
            :attr:`data_file` is not set. Default True.

    Returns:
        Dataset: instance of IMDB dataset.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.text.datasets import Imdb

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, doc, label):
                    return paddle.sum(doc), label

            paddle.disable_static()

            imdb = Imdb(mode='train')

            for i in range(10):
                doc, label = imdb[i]
                doc = paddle.to_tensor(doc)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                doc_sum, label = model(doc, label)
                print(doc_sum.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 cutoff=150,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, \
                "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, URL, MD5, 'imdb', download)

        # Build a word dictionary from the corpus
        self.word_idx = self._build_word_dict(cutoff)

        # read dataset into memory
        self._load_anno()

    def _build_word_dict(self, cutoff):
        word_freq = collections.defaultdict(int)
        pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
        for doc in self._tokenize(pattern):
            for word in doc:
                word_freq[word] += 1

        # Not sure if we should prune less-frequent words here.
        word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]

        dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
        words, _ = list(zip(*dictionary))
        word_idx = dict(list(zip(words, six.moves.range(len(words)))))
        # Reserve the last index for out-of-vocabulary words.
        word_idx['<unk>'] = len(words)
        return word_idx

    def _tokenize(self, pattern):
        data = []
        with tarfile.open(self.data_file) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # newline and punctuations removal and ad-hoc tokenization.
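                    # The whole member is read as bytes: rstrip() drops the
                    # trailing newline, translate(None, ...) deletes ASCII
                    # punctuation, and lower()/split() finish the ad-hoc word
                    # split. six.b(...) keeps this working on both Python 2
                    # str and Python 3 bytes.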
                    data.append(
                        tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                        .translate(None, six.b(string.punctuation))
                        .lower().split())
                tf = tarf.next()

        return data

    def _load_anno(self):
        pos_pattern = re.compile(r"aclImdb/{}/pos/.*\.txt$".format(self.mode))
        neg_pattern = re.compile(r"aclImdb/{}/neg/.*\.txt$".format(self.mode))

        UNK = self.word_idx['<unk>']

        self.docs = []
        self.labels = []
        # Positive reviews are labeled 0 and negative reviews 1.
        for doc in self._tokenize(pos_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(0)

        for doc in self._tokenize(neg_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(1)

    def __getitem__(self, idx):
        return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

    def __len__(self):
        return len(self.docs)
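

# --- Usage sketch (illustrative, not part of the public API) ---
# A minimal smoke test, assuming the tarball can be downloaded or that a
# local copy of aclImdb_v1.tar.gz is passed via `data_file`. Documents have
# variable length, so batching through paddle.io.DataLoader would need a
# custom collate/padding step; here samples are simply indexed one at a time.
if __name__ == '__main__':
    imdb = Imdb(mode='train')
    print('vocabulary size:', len(imdb.word_idx))
    print('number of documents:', len(imdb))
    for i in range(3):
        doc, label = imdb[i]
        # doc is a numpy array of word indices; label has shape (1,).
        print('doc length: {}, label: {}'.format(len(doc), int(label[0])))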