From 2a184840e9b96c40b43871b35be51c2a001f5c25 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Thu, 11 Oct 2018 19:41:08 +0800
Subject: [PATCH] speed up loading pretrained word2vec

---
 .../pretrained_word2vec.py | 23 ++++++++++++++-----
 .../train_and_evaluate.py  |  3 ++-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/fluid/text_matching_on_quora/pretrained_word2vec.py b/fluid/text_matching_on_quora/pretrained_word2vec.py
index 0239ff03..eda9e80a 100755
--- a/fluid/text_matching_on_quora/pretrained_word2vec.py
+++ b/fluid/text_matching_on_quora/pretrained_word2vec.py
@@ -4,22 +4,33 @@ This Module provide pretrained word-embeddings
 from __future__ import print_function
 import numpy as np
+import time, datetime
 
-def Glove840B_300D(filepath="data/glove.840B.300d.txt"):
+def Glove840B_300D(filepath, keys=None):
     """
     input: the "glove.840B.300d.txt" file path
     return: a dict, key: word (unicode), value: a numpy array with shape [300]
     """
+    if keys is not None:
+        assert(isinstance(keys, set))
     print("loading word2vec from ", filepath)
+    print("please wait for a minute.")
+    start = time.time()
     word2vec = {}
+
     with open(filepath, "r") as f:
-        lines = f.readlines()
-        for line in lines:
+        for line in f:
             info = line.strip().split()
-            word, vector = info[0], info[1:]
+            # TODO: test python3
+            word = info[0].decode('utf-8')
+            if (keys is not None) and (word not in keys):
+                continue
+            vector = info[1:]
             assert(len(vector) == 300)
-            #TODO: test python3
-            word2vec[word.decode('utf-8')] = np.asarray(vector, dtype='float32')
+            word2vec[word] = np.asarray(vector, dtype='float32')
+
+    end = time.time()
+    print("Spent ", str(datetime.timedelta(seconds=end-start)), " on loading word2vec.")
 
     return word2vec
 
 if __name__ == '__main__':
diff --git a/fluid/text_matching_on_quora/train_and_evaluate.py b/fluid/text_matching_on_quora/train_and_evaluate.py
index 556010dc..a2d97b42 100755
--- a/fluid/text_matching_on_quora/train_and_evaluate.py
+++ b/fluid/text_matching_on_quora/train_and_evaluate.py
@@ -219,7 +219,8 @@ def main():
 
     # load pretrained_word_embedding
    if global_config.use_pretrained_word_embedding:
-        word2vec = Glove840B_300D(filepath=os.path.join(DATA_DIR, "glove.840B.300d.txt"))
+        word2vec = Glove840B_300D(filepath=os.path.join(DATA_DIR, "glove.840B.300d.txt"),
+                                  keys=set(word_dict.keys()))
         pretrained_word_embedding = utils.get_pretrained_word_embedding(
             word2vec=word2vec,
             word2id=word_dict,
-- 
GitLab
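
Note: for readers who want the gist of this change outside the diff, below is a minimal, runnable Python 3 sketch of the same loading strategy: stream the GloVe file line by line instead of calling readlines(), and skip any word not in a caller-supplied vocabulary set. The load_glove name and the sample vocabulary are hypothetical and not part of this patch; the patched code itself targets Python 2 (hence the .decode('utf-8') and the "TODO: test python3" comment).

import datetime
import time

import numpy as np


def load_glove(filepath, keys=None):
    # keys: optional set of words to keep. Set membership tests are O(1),
    # so filtering while streaming avoids building vectors for the full
    # GloVe vocabulary when only the task vocabulary is needed.
    if keys is not None:
        assert isinstance(keys, set)
    start = time.time()
    word2vec = {}
    with open(filepath, "r", encoding="utf-8") as f:
        # Iterating the file object yields one line at a time, so the
        # multi-gigabyte embedding file is never held in memory at once
        # (readlines() would materialize every line up front).
        for line in f:
            info = line.rstrip("\n").split(" ")
            word = info[0]
            if keys is not None and word not in keys:
                continue
            vector = info[1:]
            if len(vector) != 300:
                continue  # skip rare malformed lines instead of asserting
            word2vec[word] = np.asarray(vector, dtype="float32")
    print("Spent", datetime.timedelta(seconds=time.time() - start),
          "on loading word2vec.")
    return word2vec


if __name__ == "__main__":
    vocab = {"the", "quick", "brown", "fox"}  # hypothetical vocabulary
    embeddings = load_glove("data/glove.840B.300d.txt", keys=vocab)
    print(len(embeddings), "vectors loaded")

The design choice mirrors the patch: passing keys=set(word_dict.keys()) bounds the returned dict by the task vocabulary rather than the full pretrained vocabulary, which cuts both load time and peak memory.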