# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random


class Dataset:
    """Base class for the dataset sources in this module; it holds no state."""

    def __init__(self):
        """Create the (stateless) base object; nothing needs initialising."""


class SyntheticDataset(Dataset):
    """Source of randomly generated sparse-id samples.

    Every sample is a flat list of slots, and every slot is a list of
    ``ids_per_slot`` ids drawn with ``random.randint`` from
    ``[0, sparse_feature_dim - 1]``.  Samples produced in train mode are laid
    out as ``query_slot_num`` query slots, then ``title_slot_num`` positive
    title slots, then ``title_slot_num`` negative title slots; samples produced
    otherwise omit the negative title slots.
    """

    def __init__(self, sparse_feature_dim, query_slot_num, title_slot_num,
                 ids_per_slot=10, dataset_size=10000):
        """Store the sample-shape parameters.

        Args:
            sparse_feature_dim: size of the id space; ids are drawn from
                ``[0, sparse_feature_dim - 1]``.
            query_slot_num: number of query slots per sample.
            title_slot_num: number of positive title slots per sample (and,
                in train mode, the number of negative title slots too).
            ids_per_slot: number of random ids in each slot.  Defaults to 10,
                the value that was previously hard-coded.
            dataset_size: number of samples each reader yields.  Defaults to
                10000, the value that was previously hard-coded.
        """
        Dataset.__init__(self)
        # ids are randomly generated
        self.ids_per_slot = ids_per_slot
        self.sparse_feature_dim = sparse_feature_dim
        self.query_slot_num = query_slot_num
        self.title_slot_num = title_slot_num
        self.dataset_size = dataset_size

    def _reader_creator(self, is_train):
        """Build a zero-argument reader function over synthetic samples.

        Args:
            is_train: when true, each sample additionally carries
                ``title_slot_num`` negative title slots after the positive
                ones.

        Returns:
            A callable that, on each call, returns a fresh generator yielding
            ``self.dataset_size`` samples (each a list of slots).
        """
        def generate_ids(num, space):
            # `num` uniform ids in [0, space - 1].
            return [random.randint(0, space - 1) for _ in range(num)]

        def generate_slots(slot_num):
            # `slot_num` slots, each holding ids_per_slot random ids.
            return [generate_ids(self.ids_per_slot, self.sparse_feature_dim)
                    for _ in range(slot_num)]

        def reader():
            for _ in range(self.dataset_size):
                # Draw query, then positive, then negative slots so the
                # sequence of random draws is unchanged under a fixed seed.
                query_slots = generate_slots(self.query_slot_num)
                pos_title_slots = generate_slots(self.title_slot_num)
                if is_train:
                    neg_title_slots = generate_slots(self.title_slot_num)
                    yield query_slots + pos_title_slots + neg_title_slots
                else:
                    yield query_slots + pos_title_slots

        return reader

    def train(self):
        """Return a reader over training samples (negative slots included)."""
        return self._reader_creator(True)

    def valid(self):
        """Return a reader over validation samples.

        Validation deliberately uses the training layout, i.e. negative title
        slots are included.
        """
        return self._reader_creator(True)

    def test(self):
        """Return a reader over test samples (no negative title slots)."""
        return self._reader_creator(False)