# reader.py
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

class Dataset:
    """Base class for dataset readers; it defines no state or behavior."""

    def __init__(self):
        """Create the instance without setting any attributes."""

class Vocab:
    """Base class for vocabularies; it defines no state or behavior."""

    def __init__(self):
        """Create the instance without setting any attributes."""

class YoochooseVocab(Vocab):
    """Vocabulary built from whitespace-separated tokens found in text files.

    Attributes:
        vocab: dict mapping each distinct token (as read from the file) to
            an integer id, assigned in order of first appearance.
        word_array: list holding the id of every token read, in reading
            order (repeated tokens contribute their id again).
    """

    def __init__(self):
        self.vocab = {}
        self.word_array = []

    def load(self, filelist):
        """Read every file in ``filelist`` and extend the vocabulary.

        Each line is split on whitespace. A token seen for the first time
        receives the next unused id; every token occurrence, new or not,
        appends its id to ``word_array``.

        Args:
            filelist: iterable of paths to text files.
        """
        vocab = self.vocab
        word_array = self.word_array
        for path in filelist:
            with open(path, "r") as fin:
                for line in fin:
                    for item in line.split():
                        # The next free id is the current vocabulary size.
                        # Deriving it from the dict (rather than a counter
                        # restarted at 0) keeps ids unique when load() is
                        # called more than once on the same instance.
                        word_array.append(vocab.setdefault(item, len(vocab)))

    def get_vocab(self):
        """Return the token -> id dict (the live object, not a copy)."""
        return self.vocab

    def _get_word_array(self):
        """Return the list of ids in reading order (the live object)."""
        return self.word_array

class YoochooseDataset(Dataset):
    """Builds sample readers over files holding one token sequence per line.

    Every usable line (two or more whitespace-separated tokens) becomes one
    sample: all tokens but the last form the source sequence and the last
    token is the positive target.
    """

    def __init__(self, vocab_size):
        # Negative samples are drawn from the integer range [0, vocab_size).
        self.vocab_size = vocab_size

    def sample_neg(self):
        """Return a random integer in ``[0, vocab_size - 1]`` (inclusive)."""
        return random.randint(0, self.vocab_size - 1)

    def sample_neg_from_seq(self, seq):
        """Return one randomly chosen element of ``seq``.

        ``seq`` must be non-empty; an empty sequence makes
        ``random.randint`` raise ``ValueError``.
        """
        return seq[random.randint(0, len(seq) - 1)]

    def _reader_creator(self, filelist, is_train):
        """Return a zero-argument generator function over ``filelist``.

        Args:
            filelist: iterable of paths to text files.
            is_train: when true, each sample carries a third element with
                one negative id from ``sample_neg()``.

        Returns:
            A callable that, when invoked, yields ``[src, pos_tgt, neg_tgt]``
            (training) or ``[src, pos_tgt]`` (otherwise), where ``src`` is a
            list of tokens and ``pos_tgt`` / ``neg_tgt`` are one-element
            lists.
        """

        def reader():
            for path in filelist:
                with open(path, 'r') as fin:
                    for line in fin:
                        ids = line.split()
                        # A sample needs at least one source token plus a
                        # target, so blank and single-token lines are skipped.
                        if len(ids) <= 1:
                            continue
                        # Tokens are passed through as the strings read from
                        # the file, while the negative sample is an int.
                        # NOTE(review): presumably the consumer casts both to
                        # a numeric dtype -- confirm against the caller.
                        src = ids[:-1]
                        pos_tgt = [ids[-1]]
                        if is_train:
                            yield [src, pos_tgt, [self.sample_neg()]]
                        else:
                            yield [src, pos_tgt]

        return reader

    def train(self, file_list):
        """Return a reader whose samples include a negative target."""
        return self._reader_creator(file_list, True)

    def test(self, file_list):
        """Return a reader whose samples omit the negative target."""
        return self._reader_creator(file_list, False)