# -*- coding:utf8 -*-
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module implements the Vocab class for converting strings to ids and back.
"""

import io

import numpy as np


class Vocab(object):
    """
    Implements a vocabulary to store the tokens in the data, with their corresponding embeddings.
    """
    def __init__(self, filename=None, initial_tokens=None, lower=False):
        self.id2token = {}
        self.token2id = {}
        self.token_cnt = {}
        self.lower = lower

        self.embed_dim = None
        self.embeddings = None

        # distinct markers for padding and unknown words; they must not collide,
        # otherwise the two special tokens would end up sharing a single id
        self.pad_token = '<blank>'
        self.unk_token = '<unk>'

        self.initial_tokens = initial_tokens if initial_tokens is not None else []
        self.initial_tokens.extend([self.pad_token, self.unk_token])
        for token in self.initial_tokens:
            self.add(token)

        if filename is not None:
            self.load_from_file(filename)

    def size(self):
        """
        get the size of vocabulary
        Returns:
            an integer indicating the size
        """
        return len(self.id2token)

    def load_from_file(self, file_path):
        """
        loads the vocab from file_path
        Args:
            file_path: a file with a word in each line
        """
        with open(file_path, 'r') as fin:
            for line in fin:
                token = line.rstrip('\n')
                self.add(token)

    def get_id(self, token):
        """
        gets the id of a token, returns the id of unk token if token is not in vocab
        Args:
            token: a string indicating the word
        Returns:
            an integer
        """
        token = token.lower() if self.lower else token
        try:
            return self.token2id[token]
        except KeyError:
            return self.token2id[self.unk_token]

    def get_token(self, idx):
        """
        gets the token corresponding to idx, returns unk token if idx is not in vocab
        Args:
            idx: an integer
        Returns:
            a token string
        """
        try:
            return self.id2token[idx]
        except KeyError:
            return self.unk_token

    def add(self, token, cnt=1):
        """
        adds the token to vocab
        Args:
            token: a string
            cnt: a num indicating the count of the token to add, default is 1
        Returns:
            the id of the token in vocab
        """
        token = token.lower() if self.lower else token
        if token in self.token2id:
            idx = self.token2id[token]
        else:
            idx = len(self.id2token)
            self.id2token[idx] = token
            self.token2id[token] = idx
        if cnt > 0:
            if token in self.token_cnt:
                self.token_cnt[token] += cnt
            else:
                self.token_cnt[token] = cnt
        return idx

    def filter_tokens_by_cnt(self, min_cnt):
        """
        filter the tokens in vocab by their count
        Args:
            min_cnt: tokens with frequency less than min_cnt are filtered
        """
        filtered_tokens = [token for token in self.token2id
                           if self.token_cnt[token] >= min_cnt]
        # rebuild the token x id map
        self.token2id = {}
        self.id2token = {}
        for token in self.initial_tokens:
            self.add(token, cnt=0)
        for token in filtered_tokens:
            self.add(token, cnt=0)

    def randomly_init_embeddings(self, embed_dim):
        """
        randomly initializes the embeddings for each token
        Args:
            embed_dim: the size of the embedding for each token
        """
        self.embed_dim = embed_dim
        self.embeddings = np.random.rand(self.size(), embed_dim)
        # the pad and unk tokens keep all-zero embeddings
        for token in [self.pad_token, self.unk_token]:
            self.embeddings[self.get_id(token)] = np.zeros([self.embed_dim])
    def load_pretrained_embeddings(self, embedding_path):
        """
        loads the pretrained embeddings from embedding_path,
        tokens not in pretrained embeddings will be filtered
        Args:
            embedding_path: the path of the pretrained embedding file
        """
        trained_embeddings = {}
        # each line of the embedding file is expected to be: token dim_1 dim_2 ...
        # io.open yields unicode text on both Python 2 and 3, so no explicit decode is needed
        with io.open(embedding_path, 'r', encoding='utf-8') as fin:
            for line in fin:
                contents = line.strip().split()
                token = contents[0]
                if token not in self.token2id:
                    continue
                trained_embeddings[token] = list(map(float, contents[1:]))
                if self.embed_dim is None:
                    self.embed_dim = len(contents) - 1
        filtered_tokens = trained_embeddings.keys()
        # rebuild the token x id map
        self.token2id = {}
        self.id2token = {}
        for token in self.initial_tokens:
            self.add(token, cnt=0)
        for token in filtered_tokens:
            self.add(token, cnt=0)
        # load embeddings
        self.embeddings = np.zeros([self.size(), self.embed_dim])
        for token in self.token2id.keys():
            if token in trained_embeddings:
                self.embeddings[self.get_id(token)] = trained_embeddings[token]

    def convert_to_ids(self, tokens):
        """
        Convert a list of tokens to ids, use unk_token if the token is not in vocab.
        Args:
            tokens: a list of tokens
        Returns:
            a list of ids
        """
        vec = [self.get_id(token) for token in tokens]
        return vec

    def recover_from_ids(self, ids, stop_id=None):
        """
        Convert a list of ids to tokens, stop converting if the stop_id is encountered
        Args:
            ids: a list of ids to convert
            stop_id: the stop id, default is None
        Returns:
            a list of tokens
        """
        tokens = []
        for i in ids:
            tokens.append(self.get_token(i))
            if stop_id is not None and i == stop_id:
                break
        return tokens
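

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module, added for illustration):
    # build a small in-memory vocabulary, drop rare tokens, attach random embeddings,
    # and round-trip a token sequence through convert_to_ids / recover_from_ids.
    # The sample words below are purely illustrative.
    vocab = Vocab(lower=True)
    for word in ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']:
        vocab.add(word)
    vocab.filter_tokens_by_cnt(min_cnt=2)        # keeps 'the' plus the pad/unk tokens
    vocab.randomly_init_embeddings(embed_dim=8)  # pad/unk rows are zeroed
    ids = vocab.convert_to_ids(['the', 'quick', 'fox'])
    print(ids)                                   # unseen words map to the unk id
    print(vocab.recover_from_ids(ids))           # e.g. ['the', '<unk>', '<unk>']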