reader.py 3.2 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

class Dataset:
    """Stateless base class for dataset reader providers."""

    def __init__(self):
        """Create the instance; the base class stores no attributes."""

class Vocab:
    """Stateless base class for vocabulary objects."""

    def __init__(self):
        """Create the instance; the base class stores no attributes."""

class YoochooseVocab(Vocab):
    """Vocabulary built from whitespace-separated token files.

    Attributes:
        vocab: dict mapping each token string to a consecutive integer id,
            assigned in first-seen order starting from 0.
        word_array: list holding the id of every token occurrence read, in
            file order (one entry per occurrence, so repeated tokens appear
            repeatedly).
    """

    def __init__(self):
        self.vocab = {}
        self.word_array = []

    def load(self, filelist):
        """Read token files and extend ``vocab`` and ``word_array``.

        Args:
            filelist: iterable of file paths. Each line of each file is split
                on whitespace and every token is recorded.
        """
        # Continue numbering after ids assigned by earlier load() calls so a
        # repeated call never hands a new token an id that is already taken.
        idx = len(self.vocab)
        for f in filelist:
            with open(f, "r") as fin:
                for line in fin:
                    # str.split() with no argument also drops surrounding
                    # whitespace, so no separate strip() is needed.
                    for item in line.split():
                        if item not in self.vocab:
                            self.vocab[item] = idx
                            idx += 1
                        self.word_array.append(self.vocab[item])

    def get_vocab(self):
        """Return the token -> id dict (the live object, not a copy)."""
        return self.vocab

    def _get_word_array(self):
        """Return the per-occurrence id list (the live object, not a copy)."""
        return self.word_array

class YoochooseDataset(Dataset):
    """Creates train/test sample readers over whitespace-separated token files.

    The vocabulary is taken from ``y_vocab`` at construction time.
    ``y_vocab`` must provide ``get_vocab()`` (token -> id dict) and
    ``_get_word_array()`` (per-occurrence id list), as YoochooseVocab does.
    """

    def __init__(self, y_vocab):
        # Snapshot of the vocabulary size; it is not refreshed if y_vocab
        # grows later.
        self.vocab_size = len(y_vocab.get_vocab())
        self.word_array = y_vocab._get_word_array()
        self.vocab = y_vocab.get_vocab()

    def sample_neg(self):
        """Return an id drawn uniformly from ``[0, vocab_size - 1]``."""
        return random.randint(0, self.vocab_size - 1)

    def sample_neg_from_seq(self, seq):
        """Return a uniformly chosen element of the non-empty sequence ``seq``."""
        return seq[random.randint(0, len(seq) - 1)]

    # TODO(guru4elephant): word_array keeps one entry per token occurrence,
    # which wastes memory; should be improved.
    def sample_from_word_freq(self):
        """Return an id sampled in proportion to its occurrence count.

        Drawing uniformly from the per-occurrence ``word_array`` makes each
        id's probability proportional to how often it was seen.
        """
        return self.word_array[random.randint(0, len(self.word_array) - 1)]

    def _reader_creator(self, filelist, is_train):
        """Build a zero-argument generator function over ``filelist``.

        For each line with at least two tokens, the tokens are mapped to ids.
        A random split point ``boundary`` in ``[1, len - 1]`` is chosen. The
        ids before it form the source sequence, and the id at it is the
        positive target.

        Args:
            filelist: iterable of file paths, one token sequence per line.
            is_train: if true, each sample also carries one negative target
                drawn by ``sample_from_word_freq``.

        Returns:
            A callable that, when invoked, yields
            ``[src, pos_tgt, neg_tgt]`` (train) or ``[src, pos_tgt]`` (test).
            ``pos_tgt`` and ``neg_tgt`` are single-element lists.
        """
        def reader():
            for f in filelist:
                with open(f, 'r') as fin:
                    for line in fin:
                        ids = line.strip().split()
                        # Need at least one source id plus one target id.
                        if len(ids) <= 1:
                            continue
                        # Out-of-vocabulary tokens map to id 0.
                        # NOTE(review): with a vocab that numbers from 0 (as
                        # YoochooseVocab does), 0 is also a real word's id.
                        conv_ids = [self.vocab.get(i, 0) for i in ids]
                        boundary = random.randint(1, len(ids) - 1)
                        src = conv_ids[:boundary]
                        pos_tgt = [conv_ids[boundary]]
                        if is_train:
                            neg_tgt = [self.sample_from_word_freq()]
                            yield [src, pos_tgt, neg_tgt]
                        else:
                            yield [src, pos_tgt]
        return reader

    def train(self, file_list):
        """Return a reader yielding ``[src, pos_tgt, neg_tgt]`` samples."""
        return self._reader_creator(file_list, True)

    def test(self, file_list):
        """Return a reader yielding ``[src, pos_tgt]`` samples."""
        return self._reader_creator(file_list, False)