reader.py 2.7 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
Z
zhang wenhui 已提交
16
import io
D
dongdaxiang 已提交
17

Z
add ssr  
zhangwenhui03 已提交
18

D
dongdaxiang 已提交
19 20 21 22
class Dataset:
    """Base class for datasets; defines no state or behaviour of its own."""

    def __init__(self):
        """Create the base object; nothing is initialised here."""

Z
add ssr  
zhangwenhui03 已提交
23

D
dongdaxiang 已提交
24 25 26 27
class Vocab:
    """Base class for vocabularies; defines no state or behaviour of its own."""

    def __init__(self):
        """Create the base object; nothing is initialised here."""

Z
add ssr  
zhangwenhui03 已提交
28

D
dongdaxiang 已提交
29 30 31 32
class YoochooseVocab(Vocab):
    """Vocabulary mapping whitespace-separated tokens to integer ids.

    Ids are assigned in order of first appearance, starting at 0.
    """

    def __init__(self):
        # token -> integer id
        self.vocab = {}
        # id of every token read so far, in reading order (with repeats)
        self.word_array = []

    def load(self, filelist):
        """Read the given files and extend the vocabulary with their tokens.

        Every file is opened as UTF-8 text and every line is split on
        whitespace. A token seen for the first time receives the next free
        id; the id of every token (new or known) is appended to the word
        array.

        Calling ``load`` more than once on the same instance continues the
        numbering where the previous call stopped, so ids never collide.

        Args:
            filelist: iterable of file paths to read.
        """
        vocab = self.vocab
        word_array = self.word_array
        for path in filelist:
            with io.open(path, "r", encoding='utf-8') as fin:
                for line in fin:
                    for item in line.split():
                        # Ids are dense (0..len-1), so the current size is
                        # always the next unused id. Deriving it from the
                        # dict instead of a counter reset to 0 keeps ids
                        # unique across repeated load() calls.
                        word_array.append(vocab.setdefault(item, len(vocab)))

    def get_vocab(self):
        """Return the token -> id dictionary (the live object, not a copy)."""
        return self.vocab

    def _get_word_array(self):
        """Return the list of ids in reading order (the live object)."""
        return self.word_array

Z
add ssr  
zhangwenhui03 已提交
54

D
dongdaxiang 已提交
55
class YoochooseDataset(Dataset):
    """Builds sample readers over files of whitespace-separated token lines.

    Each usable line (two or more tokens) becomes one sample: all tokens but
    the last form the source sequence and the last token is the positive
    target.
    """

    def __init__(self, vocab_size):
        """Store the vocabulary size used as the negative-sampling range.

        Args:
            vocab_size: number of ids; negatives are drawn from
                ``0 .. vocab_size - 1``.
        """
        self.vocab_size = vocab_size

    def sample_neg(self):
        """Return a random integer id in ``[0, vocab_size - 1]``."""
        return random.randint(0, self.vocab_size - 1)

    def sample_neg_from_seq(self, seq):
        """Return a randomly chosen element of ``seq``.

        Raises:
            ValueError: if ``seq`` is empty (raised by ``random.randint``).
        """
        return seq[random.randint(0, len(seq) - 1)]

    def _reader_creator(self, filelist, is_train):
        """Return a generator function that yields samples from ``filelist``.

        Args:
            filelist: iterable of file paths, each opened as UTF-8 text.
            is_train: when true, every sample also carries one negative
                target drawn with ``sample_neg``.

        Returns:
            A zero-argument callable. Iterating its result yields
            ``[src, pos_tgt, neg_tgt]`` when ``is_train`` is true and
            ``[src, pos_tgt]`` otherwise. ``src`` and ``pos_tgt`` hold the
            tokens exactly as read (strings, no numeric conversion), while
            ``neg_tgt`` holds an int.
        """

        def reader():
            for f in filelist:
                with io.open(f, 'r', encoding='utf-8') as fin:
                    for line in fin:
                        ids = line.split()
                        # A line needs at least one source token plus the
                        # target token to form a sample.
                        if len(ids) <= 1:
                            continue
                        src = ids[:-1]
                        pos_tgt = [ids[-1]]
                        if is_train:
                            neg_tgt = [self.sample_neg()]
                            yield [src, pos_tgt, neg_tgt]
                        else:
                            yield [src, pos_tgt]

        return reader

    def train(self, file_list):
        """Return a reader over ``file_list`` that includes negative targets."""
        return self._reader_creator(file_list, True)

    def test(self, file_list):
        """Return a reader over ``file_list`` without negative targets."""
        return self._reader_creator(file_list, False)