# get_recall.py (2.6 KB) -- last committed by tangjiji
import sys

# Ground-truth file: one "<image_id>\t<sent_id>" pair per line.
DEV_DATA_FILE = './data/flickr/flickr.dev.data'
# Result file used when no path is given on the command line.
DEFAULT_RESULT_FILE = './result'
# A query sentence is scored as soon as this many candidates were read for it.
CANDIDATES_PER_QUERY = 1000


class RecallStats(object):
    """Running recall@1/5/10 and rank totals over the evaluated queries."""

    def __init__(self):
        self.cnt = 0
        self.r1 = 0.0
        self.r5 = 0.0
        self.r10 = 0.0
        self.idx_all = 0.0

    def add(self, ans_idx):
        """Record one query whose best correct answer sits at 0-based ans_idx."""
        if ans_idx < 1:
            self.r1 += 1.0
        if ans_idx < 5:
            self.r5 += 1.0
        if ans_idx < 10:
            self.r10 += 1.0
        # Ranks are reported 1-based.
        self.idx_all += ans_idx + 1
        self.cnt += 1

    def averages(self):
        """Return (r1, r5, r10, avg_rank) averaged over the recorded queries.

        Returns all zeros when nothing was recorded, so reporting on an
        empty or truncated result file does not raise ZeroDivisionError.
        """
        if self.cnt == 0:
            return (0.0, 0.0, 0.0, 0.0)
        total = float(self.cnt)
        return (self.r1 / total, self.r5 / total,
                self.r10 / total, self.idx_all / total)

    def progress_line(self):
        """Return the space-separated progress line (values rounded to 4)."""
        values = [self.cnt] + [round(v, 4) for v in self.averages()]
        return ' '.join(str(v) for v in values)

    def summary_line(self):
        """Return the final one-line summary."""
        return ("instance %d r1:%.4f, r5:%.4f, r10:%.4f, avg_rank:%.4f"
                % ((self.cnt,) + self.averages()))


def ranked_ids(scored):
    """Return the ids of (score, id) pairs ordered from best to worst score."""
    # Indexing zip(...) directly only works on Python 2, where zip returns
    # a list; a comprehension behaves the same on Python 2 and 3.
    return [item_id for _, item_id in sorted(scored, reverse=True)]


def load_ground_truth(filename):
    """Read the tab-separated "<image_id>\\t<sent_id>" ground-truth file.

    Args:
        filename: path of the ground-truth file.

    Returns:
        (ans_dict, text_ans_dict): ans_dict maps a sentence id to its image
        id; text_ans_dict maps an image id to the list of its sentence ids.
    """
    ans_dict = {}
    text_ans_dict = {}
    with open(filename) as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or truncated line: nothing to record.
                continue
            image_id = fields[0].strip(' ')
            sent_id = fields[1].strip(' ')
            ans_dict[sent_id] = image_id
            text_ans_dict.setdefault(image_id, []).append(sent_id)
    return ans_dict, text_ans_dict


def evaluate_image_retrieval(lines, ans_dict,
                             candidates_per_query=CANDIDATES_PER_QUERY,
                             report_every=100):
    """Score sentence -> image retrieval from "<score>\\t<image_id>\\t<sent_id>" lines.

    A sentence is scored at the moment its candidate list reaches exactly
    candidates_per_query entries; sentences with fewer candidates are
    never scored.

    Args:
        lines: iterable of result lines (e.g. an open file).
        ans_dict: sentence id -> correct image id.
        candidates_per_query: candidate count that triggers scoring.
        report_every: print a progress line every this many scored queries.

    Returns:
        (stats, text_res_dict): the RecallStats for this direction, and a
        dict mapping each image id to its list of (score, sent_id) pairs.

    Raises:
        KeyError: a scored sentence id is missing from ans_dict.
        ValueError: the correct image is not among a sentence's candidates.
    """
    stats = RecallStats()
    res_dict = {}
    text_res_dict = {}
    for line in lines:
        fields = line.strip().split('\t')
        if len(fields) != 3:
            # Kept from the original script: the first line that does not
            # have three fields ends parsing of the result file.
            break
        score, image_id, sent_id = float(fields[0]), fields[1], fields[2]
        candidates = res_dict.setdefault(sent_id, [])
        candidates.append((score, image_id))
        text_res_dict.setdefault(image_id, []).append((score, sent_id))
        if len(candidates) == candidates_per_query:
            ranking = ranked_ids(candidates)
            stats.add(ranking.index(ans_dict[sent_id].strip()))
            if stats.cnt % report_every == 0:
                print(stats.progress_line())
    return stats, text_res_dict


def evaluate_text_retrieval(text_res_dict, text_ans_dict, report_every=500):
    """Score image -> sentence retrieval.

    For every image, the rank used is that of its best-ranked correct
    sentence.

    Args:
        text_res_dict: image id -> list of (score, sent_id) pairs.
        text_ans_dict: image id -> list of correct sentence ids.
        report_every: print a progress line every this many scored images.

    Returns:
        The RecallStats for this direction.

    Raises:
        KeyError: an image id is missing from text_ans_dict.
        ValueError: a correct sentence is not among an image's candidates.
    """
    stats = RecallStats()
    for image_id in text_res_dict:
        ranking = ranked_ids(text_res_dict[image_id])
        best_idx = min(ranking.index(sent_id.strip())
                       for sent_id in text_ans_dict[image_id])
        stats.add(best_idx)
        if stats.cnt % report_every == 0:
            print(stats.progress_line())
    return stats


def main(argv=None):
    """Run both evaluations; argv[1], when given, is the result file path."""
    if argv is None:
        argv = sys.argv
    res_file = argv[1] if len(argv) > 1 else DEFAULT_RESULT_FILE

    ans_dict, text_ans_dict = load_ground_truth(DEV_DATA_FILE)

    # Every print below takes a single string so the output is the same
    # under Python 2 (print statement) and Python 3 (print function).
    print('=============== IMAGE RETRIEVAL ==================')
    with open(res_file) as f:
        image_stats, text_res_dict = evaluate_image_retrieval(f, ans_dict)
    print('-----------------------------')
    print(image_stats.summary_line())

    print('\n=============== TEXT RETRIEVAL ==================')
    text_stats = evaluate_text_retrieval(text_res_dict, text_ans_dict)
    print('-----------------------------')
    print(text_stats.summary_line())


if __name__ == '__main__':
    main()