movielens.py 7.7 KB
Newer Older
K
Kaipeng Deng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import zipfile
import re

from paddle.io import Dataset
20
from paddle.dataset.common import _check_exists_and_download
K
Kaipeng Deng 已提交
21

22 23
__all__ = []

# Raw age values accepted in users.dat (presumably the MovieLens-1M age
# bracket codes -- confirm against the dataset README). UserInfo stores a
# user's age as the position of its raw value in this list.
age_table = [1, 18, 25, 35, 45, 50, 56]

# Source archive of the dataset and the MD5 digest handed to
# _check_exists_and_download when the file has to be fetched.
URL = 'https://dataset.bj.bcebos.com/movielens%2Fml-1m.zip'
MD5 = 'c4d9eecfca2ab87c1945afe126590906'


class MovieInfo(object):
    """
    Movie id, title and categories information are stored in MovieInfo.

    Args:
        index: movie id; converted with ``int()``.
        categories(list[str]): category (genre) names of the movie.
        title(str): movie title; :meth:`value` splits it on whitespace.
    """

    def __init__(self, index, categories, title):
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self, categories_dict, movie_title_dict):
        """
        Get information from a movie.

        Args:
            categories_dict(dict): maps a category name to its integer id.
            movie_title_dict(dict): maps a lower-cased title word to its
                integer id.

        Returns:
            list: ``[[index], [category ids...], [title word ids...]]``.

        Raises:
            KeyError: if a category or a lower-cased title word is missing
                from the corresponding dictionary.
        """
        return [
            [self.index],
            [categories_dict[c] for c in self.categories],
            # Lower-case each word so lookups match a lower-cased vocabulary.
            [movie_title_dict[w.lower()] for w in self.title.split()],
        ]

    def __str__(self):
        return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
            self.index,
            self.title,
            self.categories,
        )

    def __repr__(self):
        return self.__str__()


class UserInfo(object):
    """
    User id, gender, age, and job information are stored in UserInfo.

    Args:
        index: user id; converted with ``int()``.
        gender(str): ``'M'`` marks a male user; any other value is stored
            as not male.
        age: raw age value; ``int(age)`` must be an entry of the
            module-level ``age_table``, and its position in that list is
            what gets stored in ``self.age``.
        job_id: job id; converted with ``int()``.

    Raises:
        ValueError: if ``int(age)`` is not an entry of ``age_table`` or any
            of the ``int()`` conversions fails.
    """

    def __init__(self, index, gender, age, job_id):
        self.index = int(index)
        self.is_male = gender == 'M'
        # Keep the position in age_table (0..len-1) instead of the raw value
        # so the feature is a small contiguous id.
        self.age = age_table.index(int(age))
        self.job_id = int(job_id)

    def value(self):
        """
        Get information from a user.

        Returns:
            list: ``[[index], [gender], [age], [job_id]]`` where gender is 0
            for male and 1 otherwise, and age is the ``age_table`` position.
        """
        return [
            [self.index],
            [0 if self.is_male else 1],
            [self.age],
            [self.job_id],
        ]

    def __str__(self):
        # Map the stored age position back to its raw value for display.
        return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
            self.index,
            "M" if self.is_male else "F",
            age_table[self.age],
            self.job_id,
        )

    def __repr__(self):
        return str(self)


class Movielens(Dataset):
    """
    Implementation of `Movielens 1-M <https://grouplens.org/datasets/movielens/1m/>`_ dataset.

    Every sample is a tuple of 8 numpy arrays:
    ``(user_id, gender, age, job_id, movie_id, category_ids, title_word_ids, rating)``
    where ``gender`` is 0 for male and 1 otherwise, ``age`` is a position in
    the module-level ``age_table``, and ``rating`` is the value read from
    ratings.dat rescaled as ``float(rating) * 2 - 5``.

    Args:
        data_file(str): path to the ml-1m zip file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        test_ratio(float): probability that a rating line is assigned to the
            test split. Default 0.1.
        rand_seed(int): seed of the train/test split. A private NumPy
            ``RandomState`` is used, so the global NumPy RNG is not touched.
            Default 0.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of Movielens 1-M dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.text.datasets import Movielens

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, category, title, rating):
                    return paddle.sum(category), paddle.sum(title), paddle.sum(rating)


            movielens = Movielens(mode='train')

            for i in range(10):
                category, title, rating = movielens[i][-3:]
                category = paddle.to_tensor(category)
                title = paddle.to_tensor(title)
                rating = paddle.to_tensor(rating)

                model = SimpleNet()
                category, title, rating = model(category, title, rating)
                print(category.numpy().shape, title.numpy().shape, rating.numpy().shape)

    """

    def __init__(
        self,
        data_file=None,
        mode='train',
        test_ratio=0.1,
        rand_seed=0,
        download=True,
    ):
        assert mode.lower() in [
            'train',
            'test',
        ], "mode should be 'train', 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert (
                download
            ), "data_file is not set and downloading automatically is disabled"
            # Module name is 'movielens' (previously 'sentiment', an apparent
            # copy-paste slip) so the archive is filed under this dataset.
            self.data_file = _check_exists_and_download(
                data_file, URL, MD5, 'movielens', download
            )

        self.test_ratio = test_ratio
        self.rand_seed = rand_seed

        self._load_meta_info()
        self._load_data()

    def _load_meta_info(self):
        """
        Parse movies.dat and users.dat and build the lookup tables.

        Fills ``self.movie_info`` (movie id -> MovieInfo),
        ``self.user_info`` (user id -> UserInfo), ``self.movie_title_dict``
        (lower-cased title word -> id) and ``self.categories_dict``
        (category name -> id).
        """
        # Split "Title (1995)" into the title and the trailing number.
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        self.movie_info = dict()
        self.user_info = dict()
        title_word_set = set()
        categories_set = set()
        with zipfile.ZipFile(self.data_file) as package:
            with package.open('ml-1m/movies.dat') as movie_file:
                for line in movie_file:
                    line = line.decode(encoding='latin')
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    categories_set.update(categories)
                    title = pattern.match(title).group(1)
                    self.movie_info[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title
                    )
                    title_word_set.update(w.lower() for w in title.split())

            # Iteration order of a set of str depends on per-process hash
            # randomization; sorting makes the ids reproducible across runs.
            self.movie_title_dict = {
                w: i for i, w in enumerate(sorted(title_word_set))
            }
            self.categories_dict = {
                c: i for i, c in enumerate(sorted(categories_set))
            }

            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    line = line.decode(encoding='latin')
                    uid, gender, age, job, _ = line.strip().split("::")
                    self.user_info[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job
                    )

    def _load_data(self):
        """
        Read ratings.dat and keep the lines that belong to ``self.mode``.

        One random number is drawn per line. A line goes to the test split
        when the draw is below ``self.test_ratio``, otherwise to the train
        split. With equal ``rand_seed`` and ``test_ratio`` the two modes
        therefore see disjoint lines.
        """
        # RandomState(seed) yields the same stream np.random.seed(seed) +
        # np.random.random() did, so the split is unchanged, but global NumPy
        # RNG state is no longer mutated as a side effect.
        rng = np.random.RandomState(self.rand_seed)
        is_test = self.mode == 'test'
        self.data = []
        with zipfile.ZipFile(self.data_file) as package:
            with package.open('ml-1m/ratings.dat') as rating_file:
                for line in rating_file:
                    line = line.decode(encoding='latin')
                    # Draw for every line (including skipped ones) so both
                    # modes consume the random stream identically.
                    if (rng.random_sample() < self.test_ratio) != is_test:
                        continue
                    uid, mov_id, rating, _ = line.strip().split("::")
                    rating = float(rating) * 2 - 5.0

                    usr = self.user_info[int(uid)]
                    mov = self.movie_info[int(mov_id)]
                    self.data.append(
                        usr.value()
                        + mov.value(
                            self.categories_dict, self.movie_title_dict
                        )
                        + [[rating]]
                    )

    def __getitem__(self, idx):
        """Return sample ``idx`` as a tuple of numpy arrays."""
        data = self.data[idx]
        return tuple([np.array(d) for d in data])

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)