# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interactively print the cosine similarity between the feature vectors of two
words.

Example:
    python calculate_dis.py DICTIONARYTXT FEATURETXT

Required arguments:
    DICTIONARYTXT    the dictionary file generated by the data provider,
                     one word per line
    FEATURETXT       the word features in text format: a first (skipped)
                     line, then one comma-separated vector per word, in
                     dictionary order
"""

import numpy as np
from argparse import ArgumentParser

def load_dict(fdict):
    """Build a word -> index mapping from a dictionary file.

    Each line of ``fdict`` holds one word (surrounding whitespace is
    stripped); a word's index is its 0-based line number. If a word
    appears on more than one line, the index of its last occurrence wins.

    Args:
        fdict: an iterable of text lines, e.g. an open file object.

    Returns:
        dict mapping each word to its integer index.
    """
    # enumerate works on both Python 2 and 3 (unlike the Py2-only xrange),
    # and iterating directly avoids reading the whole file via readlines().
    return {line.strip(): idx for idx, line in enumerate(fdict)}

def load_emb(femb):
    """Load word feature vectors and scale each one to unit L2 norm.

    The first line of ``femb`` is skipped (assumed to be a header line).
    Every following line is parsed as a comma-separated list of floats
    describing one word; the resulting list is indexed with the word
    indices produced by ``load_dict``.

    Args:
        femb: an iterable of text lines, e.g. an open file object.

    Returns:
        list of 1-D numpy float arrays, each divided by its L2 norm. An
        all-zero vector is kept as zeros instead of being divided by zero.

    Raises:
        ValueError: if a line contains a field that is not a float.
    """
    feaBank = []
    for lineno, line in enumerate(femb):
        if lineno == 0:
            # The first line carries no vector.
            continue
        fea = np.array([float(x) for x in line.strip().split(',')])
        norm = np.linalg.norm(fea)
        # Dividing a zero vector by its zero norm would produce NaNs (plus a
        # RuntimeWarning) and make every similarity with this word NaN.
        feaBank.append(fea / norm if norm > 0 else fea)
    return feaBank

def calcos(id1, id2, Fea):
    """Return the dot product of the feature vectors at ``id1`` and ``id2``.

    When ``Fea`` holds unit-length vectors (as ``load_emb`` produces), this
    dot product is the cosine similarity of the two words.
    """
    first = Fea[id1]
    # .T is equivalent to .transpose(); for 1-D vectors it is a no-op.
    return first.T.dot(Fea[id2])

def get_wordidx(w, Dict):
    """Look up the index of word ``w`` in ``Dict``.

    Args:
        w: the word to look up.
        Dict: word -> index mapping, as built by ``load_dict``.

    Returns:
        The word's index, or -1 (after printing an error message to stdout)
        when the word is not in the dictionary.
    """
    if w not in Dict:
        # Single-argument print(...) emits the same text on Python 2 and 3;
        # the bare print statement is a SyntaxError on Python 3. The tuple
        # keeps the formatting safe even if ``w`` is itself a tuple.
        print('ERROR: %s not in the dictionary' % (w,))
        return -1
    return Dict[w]

if __name__ == '__main__':
    # Load the dictionary and the normalised feature bank, then repeatedly
    # prompt for two words and print the similarity of their vectors.
    parser = ArgumentParser(
        description='Interactively print the cosine similarity of two words.')
    parser.add_argument('dict', help='dictionary file')
    parser.add_argument('fea', help='feature file')
    args = parser.parse_args()

    with open(args.dict) as fdict:
        word_dict = load_dict(fdict)

    with open(args.fea) as ffea:
        word_fea = load_emb(ffea)

    # raw_input only exists on Python 2, where input() would eval the typed
    # text; fall back to input() on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    while True:
        try:
            words = read_line("please input two words: ").split()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
            print('')
            break
        if len(words) != 2:
            # Unpacking any other count would raise ValueError and end the
            # loop; re-prompt instead.
            print('ERROR: please input exactly two words')
            continue
        w1, w2 = words
        w1_id = get_wordidx(w1, word_dict)
        w2_id = get_wordidx(w2, word_dict)
        if w1_id == -1 or w2_id == -1:
            continue
        print('similarity: %s' % (calcos(w1_id, w2_id, word_fea)))