# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import six
from paddlerec.core.reader import ReaderBase
from paddlerec.core.utils import envs


class Reader(ReaderBase):
    """Reader for the word2vec analogy evaluation set.

    Each input line is expected to hold four space-separated words
    ``a b c d``. The words are lower-cased, out-of-vocabulary words are
    replaced by ``<UNK>``, and each word is mapped through the word-id
    dictionary loaded in ``init``. The ids are emitted as the slots
    ``analogy_a`` .. ``analogy_d``.
    """

    # Slot names emitted for the four words of an analogy line, in order.
    _ANALOGY_SLOTS = ('analogy_a', 'analogy_b', 'analogy_c', 'analogy_d')

    def init(self):
        """Load the word <-> id mapping used to encode analogy words.

        The dictionary path comes from the global config key
        ``dataset.dataset_infer.word_id_dict_path``. Each line of that file
        is parsed as ``<word> <id>``, separated by a single space; blank or
        malformed lines are skipped.

        Sets ``self.word_to_id``, ``self.id_to_word`` and
        ``self.dict_size``.
        """
        dict_path = envs.get_global_env(
            "dataset.dataset_infer.word_id_dict_path")
        self.word_to_id = dict()
        self.id_to_word = dict()
        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Split once per line rather than once per field access.
                parts = line.split(' ')
                if len(parts) < 2:
                    # Blank or id-less line (e.g. a trailing empty line):
                    # nothing to record.
                    continue
                # int() tolerates the trailing newline on the id field.
                word, word_id = parts[0], int(parts[1])
                self.word_to_id[word] = word_id
                self.id_to_word[word_id] = word
        self.dict_size = len(self.word_to_id)

    def native_to_unicode(self, s):
        """Return ``s`` as a unicode string.

        Text is returned unchanged. Bytes are decoded as UTF-8; if strict
        decoding fails, the bytes are decoded again with undecodable
        sequences dropped.
        """
        if self._is_unicode(s):
            return s
        try:
            return self._to_unicode(s)
        except UnicodeDecodeError:
            return self._to_unicode(s, ignore_errors=True)

    def _is_unicode(self, s):
        """Return True if ``s`` is a text string (``unicode`` on Py2, ``str`` on Py3)."""
        return isinstance(s, six.text_type)

    def _to_unicode(self, s, ignore_errors=False):
        """Decode ``s`` from UTF-8 unless it is already text.

        Args:
            s: a text or bytes object.
            ignore_errors: if True, drop undecodable bytes instead of raising.

        Raises:
            UnicodeDecodeError: if ``s`` is not valid UTF-8 and
                ``ignore_errors`` is False.
        """
        if self._is_unicode(s):
            return s
        error_mode = "ignore" if ignore_errors else "strict"
        return s.decode("utf-8", errors=error_mode)

    def strip_lines(self, line, vocab):
        """Decode ``line`` to unicode and replace words not in ``vocab`` with "<UNK>"."""
        return self._replace_oov(vocab, self.native_to_unicode(line))

    def _replace_oov(self, original_vocab, line):
        """Replace out-of-vocab words with "<UNK>".

        This maintains compatibility with published results.

        Args:
            original_vocab: a container of strings supporting ``in`` (the
                standard vocabulary for the dataset).
            line: a unicode string - a whitespace-delimited sequence of words.

        Returns:
            A unicode string - a single-space-delimited sequence of words.
        """
        return u" ".join([
            word if word in original_vocab else u"<UNK>"
            for word in line.split()
        ])

    def generate_sample(self, line):
        """Build a sample generator for one line of the analogy test file.

        Args:
            line: one raw input line, expected to hold four space-separated
                words ``a b c d``.

        Returns:
            A zero-argument generator function. It yields at most one
            sample: a list of ``(slot_name, [word_id])`` pairs for
            ``analogy_a`` .. ``analogy_d``. Lines containing ':' and lines
            with fewer than four words yield nothing.
        """

        def reader():
            if ':' in line:
                # Skip these lines. They are presumably category headers such
                # as ": capital-common-countries" in the analogy question set
                # (TODO confirm). Parsing them as a question would fail on the
                # missing words.
                return
            features = self.strip_lines(line.lower(), self.word_to_id).split()
            if len(features) < 4:
                # Blank or malformed line: not a complete analogy question.
                return
            # NOTE(review): OOV words were replaced by u"<UNK>" (upper-case,
            # applied after lower()). This lookup therefore assumes the id
            # dict contains a "<UNK>" entry; otherwise it raises KeyError.
            # Confirm against the dict-generation script.
            yield [(slot, [self.word_to_id[word]])
                   for slot, word in zip(self._ANALOGY_SLOTS, features)]

        return reader