# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import io class SentaReader(): def __init__(self, vocab_path, max_seq_len=20): self.max_seq_len = max_seq_len self.word_dict = self.load_vocab(vocab_path) def load_vocab(self, vocab_path): """ load the given vocabulary """ vocab = {} with io.open(vocab_path, 'r', encoding='utf8') as f: wid = 0 for line in f: if line.strip() not in vocab: vocab[line.strip()] = wid wid += 1 vocab[""] = len(vocab) return vocab def process(self, cols): unk_id = len(self.word_dict) pad_id = 0 wids = [ self.word_dict[x] if x in self.word_dict else unk_id for x in cols ] seq_len = len(wids) if seq_len < self.max_seq_len: for i in range(self.max_seq_len - seq_len): wids.append(pad_id) else: wids = wids[:self.max_seq_len] seq_len = self.max_seq_len return wids