reader.py 4.7 KB
Newer Older
Y
add din  
yaoxuefeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import random
T
tangwei 已提交
18

Y
add din  
yaoxuefeng 已提交
19 20 21 22 23
try:
    import cPickle as pickle
except ImportError:
    import pickle

T
tangwei 已提交
24 25 26 27 28 29
import numpy as np

from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs


Y
add din  
yaoxuefeng 已提交
30 31 32 33 34
class TrainReader(Reader):
    def init(self):
        """
        Load reader settings and pre-scan the training data directory.

        Reads ``train_data_path`` and ``batch_size`` through
        ``envs.get_global_env`` (called with the "train.reader" namespace
        argument; 32 is passed as the second argument for ``batch_size``,
        presumably the default -- confirm against ``envs``).

        Every file listed in ``train_data_path`` is scanned once to find the
        longest history sequence (the first ';'-separated field, split on
        whitespace). The result is stored in ``self.max_len`` and also written
        to ``tmp.txt`` in the current working directory.

        Side effects:
            Sets ``train_data_path``, ``res``, ``max_len``, ``batch_size`` and
            ``group_size`` on the instance, and (over)writes ``tmp.txt``.
        """
        self.train_data_path = envs.get_global_env("train_data_path", None, "train.reader")
        self.res = []
        self.max_len = 0

        for file_name in os.listdir(self.train_data_path):
            train_data_file = os.path.join(self.train_data_path, file_name)
            with open(train_data_file, "r") as fin:
                for line in fin:
                    # Only the history field matters for the length scan.
                    hist = line.strip().split(';')[0].split()
                    self.max_len = max(self.max_len, len(hist))

        # Context manager guarantees the handle is closed even if write() fails.
        with open("tmp.txt", "w") as fo:
            fo.write(str(self.max_len))

        self.batch_size = envs.get_global_env("batch_size", 32, "train.reader")
        # Samples are sorted by length within groups of 20 batches.
        self.group_size = self.batch_size * 20

    def _process_line(self, line):
        line = line.strip().split(';')
        hist = line[0].split()
        hist = [int(i) for i in hist]
        cate = line[1].split()
        cate = [int(i) for i in cate]
        return [hist, cate, [int(line[2])], [int(line[3])], [float(line[4])]]

    def generate_sample(self, line):
        """
        Read the data line by line and process it as a dictionary
        """
T
for mat  
tangwei 已提交
62

Y
add din  
yaoxuefeng 已提交
63
        def data_iter():
T
for mat  
tangwei 已提交
64
            # feat_idx, feat_value, label = self._process_line(line)
Y
add din  
yaoxuefeng 已提交
65 66 67
            yield self._process_line(line)

        return data_iter
T
for mat  
tangwei 已提交
68

Y
add din  
yaoxuefeng 已提交
69 70 71 72
    def pad_batch_data(self, input, max_len):
        res = np.array([x + [0] * (max_len - len(x)) for x in input])
        res = res.astype("int64").reshape([-1, max_len])
        return res
T
for mat  
tangwei 已提交
73

Y
add din  
yaoxuefeng 已提交
74 75 76 77 78 79 80
    def make_data(self, b):
        max_len = max(len(x[0]) for x in b)
        item = self.pad_batch_data([x[0] for x in b], max_len)
        cat = self.pad_batch_data([x[1] for x in b], max_len)
        len_array = [len(x[0]) for x in b]
        mask = np.array(
            [[0] * x + [-1e9] * (max_len - x) for x in len_array]).reshape(
T
for mat  
tangwei 已提交
81
            [-1, max_len, 1])
Y
add din  
yaoxuefeng 已提交
82 83 84 85 86 87 88 89 90 91 92
        target_item_seq = np.array(
            [[x[2]] * max_len for x in b]).astype("int64").reshape([-1, max_len])
        target_cat_seq = np.array(
            [[x[3]] * max_len for x in b]).astype("int64").reshape([-1, max_len])
        res = []
        for i in range(len(b)):
            res.append([
                item[i], cat[i], b[i][2], b[i][3], b[i][4], mask[i],
                target_item_seq[i], target_cat_seq[i]
            ])
        return res
T
for mat  
tangwei 已提交
93

Y
add din  
yaoxuefeng 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    def batch_reader(self, reader, batch_size, group_size):
        def batch_reader():
            bg = []
            for line in reader:
                bg.append(line)
                if len(bg) == group_size:
                    sortb = sorted(bg, key=lambda x: len(x[0]), reverse=False)
                    bg = []
                    for i in range(0, group_size, batch_size):
                        b = sortb[i:i + batch_size]
                        yield self.make_data(b)
            len_bg = len(bg)
            if len_bg != 0:
                sortb = sorted(bg, key=lambda x: len(x[0]), reverse=False)
                bg = []
                remain = len_bg % batch_size
                for i in range(0, len_bg - remain, batch_size):
                    b = sortb[i:i + batch_size]
                    yield self.make_data(b)

        return batch_reader
T
for mat  
tangwei 已提交
115

Y
add din  
yaoxuefeng 已提交
116 117 118 119 120 121 122 123 124 125
    def base_read(self, file_dir):
        res = []
        for train_file in file_dir:
            with open(train_file, "r") as fin:
                for line in fin:
                    line = line.strip().split(';')
                    hist = line[0].split()
                    cate = line[1].split()
                    res.append([hist, cate, line[2], line[3], float(line[4])])
        return res
T
for mat  
tangwei 已提交
126

Y
add din  
yaoxuefeng 已提交
127 128 129 130
    def generate_batch_from_trainfiles(self, files):
        data_set = self.base_read(files)
        random.shuffle(data_set)
        return self.batch_reader(data_set, self.batch_size, self.batch_size * 20)