From 876761025f800104185d30b7ced3f9d2884bbfe2 Mon Sep 17 00:00:00 2001 From: yinhaofeng <1841837261@qq.com> Date: Fri, 21 Aug 2020 14:45:18 +0000 Subject: [PATCH] add --- tools/cal_pos_neg.py | 90 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tools/cal_pos_neg.py diff --git a/tools/cal_pos_neg.py b/tools/cal_pos_neg.py new file mode 100644 index 00000000..8b356046 --- /dev/null +++ b/tools/cal_pos_neg.py @@ -0,0 +1,90 @@ +#!/usr/bin/python +#-*- coding:utf-8 -*- +############################ +#File Name: cal_pos_neg.py +#Author: youqiheng +#Mail: youqiheng@baidu.com +#Created Time: 2018-04-15 21:59:45 +############################ +""" +docstring +""" + +import os +import sys + +if len(sys.argv) < 2: + print "usage:python %s input" % (sys.argv[0]) + sys.exit(-1) + +fin = file(sys.argv[1]) +pos_num = 0 +neg_num = 0 + +score_list = [] +label_list = [] +last_query = "-1" + +#0 12.786960 1 +#0 -1.480890 0 +cnt = 0 +query_num = 0 +pair_num = 0 +equal_num = 0 +for line in fin: + cols = line.strip().split("\t") + cnt += 1 + if cnt % 500000 == 0: + print "cnt:", cnt, 1.0 * pos_num / neg_num + if len(cols) != 3: + continue + + cur_query = cols[0] + if cur_query != last_query: + query_num += 1 + for i in xrange(0, len(score_list)): + for j in xrange(i + 1, len(score_list)): + if label_list[i] == label_list[j]: + continue + pair_num += 1 + if (score_list[i] - score_list[j]) * ( + label_list[i] - label_list[j]) < 0: + neg_num += 1 + elif (score_list[i] - score_list[j]) * ( + label_list[i] - label_list[j]) > 0: + pos_num += 1 + else: + equal_num += 1 + score_list = [] + label_list = [] + + last_query = cur_query + + label = int(cols[2]) + + score_list.append(round(float(cols[1]), 6)) + label_list.append(int(cols[2])) + +fin.close() + +for i in xrange(0, len(score_list)): + for j in xrange(i + 1, len(score_list)): + if label_list[i] == label_list[j]: + continue + pair_num += 1 + if (score_list[i] - score_list[j]) * (label_list[i] - label_list[j] + ) < 0: + neg_num += 1 + elif (score_list[i] - score_list[j]) * (label_list[i] - label_list[j] + ) > 0: + pos_num += 1 + else: + equal_num += 1 + +if neg_num > 0: + print "pnr:", 1.0 * pos_num / neg_num + print "query_num:", query_num + print "pair_num:", pos_num + neg_num + equal_num, pair_num + print "equal_num:", equal_num + print "正序率:", 1.0 * pos_num / (pos_num + neg_num) +print pos_num, neg_num -- GitLab