reader.py 4.9 KB
Newer Older
Y
add din  
yaoxuefeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import random
T
tangwei 已提交
18

Y
add din  
yaoxuefeng 已提交
19 20 21 22 23
try:
    import cPickle as pickle
except ImportError:
    import pickle

T
tangwei 已提交
24 25 26 27 28 29
import numpy as np

from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs


Y
add din  
yaoxuefeng 已提交
30 31
class TrainReader(Reader):
    """Batching reader for ';'-separated behaviour-sequence samples.

    Each input line has five ';'-separated fields::

        hist_ids ; cate_ids ; target_item ; target_cat ; label

    ``hist_ids`` and ``cate_ids`` are space-separated integer ids,
    ``target_item`` / ``target_cat`` are single integers and the last field
    is parsed as a float (presumably the click label -- confirm with the
    model that consumes this reader).
    """

    def init(self):
        """Load config and scan the training files for the longest history.

        Reads ``dataset.sample_1.data_path`` and
        ``dataset.sample_1.batch_size`` (default 32) from the global env.
        The maximum number of history ids over all files in the data
        directory is stored in ``self.max_len`` and also written to
        ``tmp.txt`` in the current working directory (presumably read back
        by another component -- verify before changing).

        Raises:
            ValueError: if ``dataset.sample_1.data_path`` is not configured.
        """
        self.train_data_path = envs.get_global_env(
            "dataset.sample_1.data_path", None)
        if self.train_data_path is None:
            # os.listdir(None) would silently scan the current directory.
            raise ValueError("dataset.sample_1.data_path is not configured")
        self.res = []
        self.max_len = 0

        for file_name in os.listdir(self.train_data_path):
            train_data_file = os.path.join(self.train_data_path, file_name)
            if not os.path.isfile(train_data_file):
                # Subdirectories (or other non-regular entries) cannot be read
                # as sample files.
                continue
            with open(train_data_file, "r") as fin:
                for line in fin:
                    hist = line.strip().split(';')[0].split()
                    self.max_len = max(self.max_len, len(hist))

        with open("tmp.txt", "w") as fo:
            fo.write(str(self.max_len))

        self.batch_size = envs.get_global_env("dataset.sample_1.batch_size",
                                              32, "train.reader")
        # Number of samples sorted together by length before batching.
        self.group_size = self.batch_size * 20

    def _process_line(self, line):
        """Parse one raw sample line.

        Returns:
            ``[hist, cate, [target_item], [target_cat], [label]]`` where
            ``hist``/``cate`` are lists of ints, the two targets are ints and
            ``label`` is a float.
        """
        fields = line.strip().split(';')
        hist = [int(i) for i in fields[0].split()]
        cate = [int(i) for i in fields[1].split()]
        return [hist, cate, [int(fields[2])], [int(fields[3])],
                [float(fields[4])]]

    def generate_sample(self, line):
        """Return a generator function yielding the parsed form of ``line``.

        The generator yields exactly one list, as produced by
        :meth:`_process_line`.
        """

        def data_iter():
            yield self._process_line(line)

        return data_iter

    def pad_batch_data(self, input, max_len):
        """Right-pad each id list in ``input`` with 0 up to ``max_len``.

        Elements may be ints or numeric strings (as produced by
        :meth:`base_read`); numpy parses them during the int64 cast.

        Returns:
            int64 ndarray of shape ``(len(input), max_len)``.
        """
        res = np.array([x + [0] * (max_len - len(x)) for x in input])
        return res.astype("int64").reshape([-1, max_len])

    def make_data(self, b):
        """Turn a non-empty list of samples into padded per-sample features.

        Every sample in ``b`` is ``[hist, cate, target_item, target_cat,
        label]``. The batch is padded to its own longest history.

        Returns:
            A list with one entry per sample:
            ``[item, cat, target_item, target_cat, label, mask,
            target_item_seq, target_cat_seq]`` where ``item``/``cat`` are
            int64 arrays of shape ``(max_len,)``, ``mask`` is a float array of
            shape ``(max_len, 1)`` holding 0 for real positions and -1e9 for
            padding (presumably added to attention logits), and the two
            ``*_seq`` arrays repeat the target id ``max_len`` times as int64.
            ``target_item``, ``target_cat`` and ``label`` are passed through
            unchanged from the sample.
        """
        max_len = max(len(x[0]) for x in b)
        item = self.pad_batch_data([x[0] for x in b], max_len)
        cat = self.pad_batch_data([x[1] for x in b], max_len)
        len_array = [len(x[0]) for x in b]
        # Build from floats so the mask dtype does not depend on whether any
        # row in this batch actually needs padding.
        mask = np.array(
            [[0.0] * x + [-1e9] * (max_len - x) for x in len_array]).reshape(
                [-1, max_len, 1])
        target_item_seq = np.array(
            [[x[2]] * max_len for x in b]).astype("int64").reshape(
                [-1, max_len])
        target_cat_seq = np.array(
            [[x[3]] * max_len for x in b]).astype("int64").reshape(
                [-1, max_len])
        return [[
            item[i], cat[i], sample[2], sample[3], sample[4], mask[i],
            target_item_seq[i], target_cat_seq[i]
        ] for i, sample in enumerate(b)]

    def batch_reader(self, reader, batch_size, group_size):
        """Wrap ``reader`` into a generator function of padded batches.

        Samples are buffered ``group_size`` at a time. Each buffer is sorted
        by history length, to reduce padding, and cut into chunks of
        ``batch_size`` that go through :meth:`make_data`. For the final,
        partially filled buffer only complete batches are emitted, and the
        leftover samples (fewer than ``batch_size``) are dropped.

        Args:
            reader: iterable of samples, e.g. from :meth:`base_read`. It is
                iterated once per call of the returned function, so pass a
                re-iterable (such as a list) if it is to be reused.
            batch_size: number of samples per batch.
            group_size: number of samples sorted together. It is expected to
                be a multiple of ``batch_size``; otherwise each full group
                also yields one short batch.

        Returns:
            A zero-argument generator function yielding ``make_data`` outputs.
        """

        def _batches(group, drop_last):
            # Stable ascending sort by history length keeps similarly sized
            # samples in the same batch.
            ordered = sorted(group, key=lambda sample: len(sample[0]))
            end = len(ordered)
            if drop_last:
                end -= end % batch_size
            for start in range(0, end, batch_size):
                yield self.make_data(ordered[start:start + batch_size])

        def batch_reader():
            buffered = []
            for sample in reader:
                buffered.append(sample)
                if len(buffered) == group_size:
                    for batch in _batches(buffered, drop_last=False):
                        yield batch
                    buffered = []
            if buffered:
                for batch in _batches(buffered, drop_last=True):
                    yield batch

        return batch_reader

    def base_read(self, file_dir):
        """Load every sample from the given files into memory.

        Args:
            file_dir: iterable of file paths (despite the name, not a
                directory path).

        Returns:
            A list of ``[hist, cate, target_item, target_cat, label]`` entries.
            ``hist``/``cate`` are lists of id strings and the targets are
            strings. Only ``label`` is converted (to float) here; the id
            fields are cast to int64 by numpy in :meth:`make_data`.
            NOTE(review): ``target_item``/``target_cat`` therefore reach the
            batch output as strings. Confirm that the consumer casts them.
        """
        res = []
        for train_file in file_dir:
            with open(train_file, "r") as fin:
                for line in fin:
                    fields = line.strip().split(';')
                    res.append([
                        fields[0].split(), fields[1].split(), fields[2],
                        fields[3], float(fields[4])
                    ])
        return res

    def generate_batch_from_trainfiles(self, files):
        """Read ``files``, shuffle the samples and return a batch generator.

        The shuffle uses the global ``random`` state. Groups of
        ``20 * self.batch_size`` samples are length-sorted before batching
        (see :meth:`batch_reader`).
        """
        data_set = self.base_read(files)
        random.shuffle(data_set)
        return self.batch_reader(data_set, self.batch_size,
                                 self.batch_size * 20)