random_reader.py 2.2 KB
Newer Older
F
frankwhzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

16 17 18
import numpy as np
import paddle.fluid as fluid

C
Chengmo 已提交
19
from paddlerec.core.reader import ReaderBase
F
frankwhzhang 已提交
20 21 22 23
from paddlerec.core.utils import envs
from collections import defaultdict


C
Chengmo 已提交
24
class Reader(ReaderBase):
    """Reader that serves randomly generated batches instead of file data.

    All sizes come from the ``hyper_parameters`` section of the global
    config (see ``init``); input files are never opened.
    """

    def init(self):
        """Load the synthetic-data sizes from the global config.

        Presumably invoked by ``ReaderBase`` during setup (the base class is
        not visible here -- TODO confirm).
        """
        read_env = envs.get_global_env
        self.user_vocab = read_env("hyper_parameters.user_vocab")
        self.item_vocab = read_env("hyper_parameters.item_vocab")
        self.item_len = read_env("hyper_parameters.item_len")
        self.batch_size = read_env("hyper_parameters.batch_size")

    def reader_creator(self):
        """Build a generator function that emits one random batch.

        Returns:
            A zero-argument generator function. Each invocation yields a
            single list ``[user_ids, item_ids, lengths, labels]``:

            * ``user_ids``: ``batch_size`` one-element lists, each holding
              an int in ``[0, user_vocab)``.
            * ``item_ids``: ``batch_size`` rows of ``item_len`` ints in
              ``[0, item_vocab)``.
            * ``lengths``: ``batch_size`` copies of ``item_len``.
            * ``labels``: ``batch_size`` rows of ``item_len`` values in
              ``{0, 1}``.

        Attributes are read when the generator runs, not when it is built.
        """

        def reader():
            rows = self.batch_size
            # Draw order (users, items, labels) is kept fixed so seeded runs
            # produce the same data.
            user_ids = [[int(np.random.randint(self.user_vocab))]
                        for _ in range(rows)]
            item_ids = np.random.randint(
                self.item_vocab, size=(rows, self.item_len)).tolist()
            lengths = [self.item_len] * rows
            labels = np.random.randint(
                2, size=(rows, self.item_len)).tolist()
            yield [user_ids, item_ids, lengths, labels]

        return reader

    def generate_batch_from_trainfiles(self, files):
        """Return the random reader wrapped by ``fluid.io.batch``.

        ``files`` is accepted for interface compatibility but ignored.
        NOTE(review): the inner reader already yields a whole batch, and it
        is grouped again here with ``batch_size`` -- confirm the consumer
        expects this extra nesting.
        """
        batched_reader = fluid.io.batch(
            self.reader_creator(), batch_size=self.batch_size)
        return batched_reader

    def generate_sample(self, line):
        """Return a placeholder per-line reader; ``line`` is ignored.

        This reader does not use input files, so the returned function is a
        stub that does nothing and returns ``None``.
        """

        def reader():
            """Placeholder: implement according to the actual data format."""
            return None

        return reader