#encoding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Compute pairwise ranking statistics (PNR) for scored query/document data.

Usage:
    python cal_pos_neg.py input

The input file is tab-separated with three columns per line, e.g.

    0       12.786960       1
    0       -1.480890       0

i.e. query id, model score, integer label. Lines with any other number of
columns are skipped. Lines belonging to one query are assumed to be
contiguous: a query id that reappears after a different one is counted as a
new query.

For every pair of lines within the same query whose labels differ, the pair
is "positive" when the score order agrees with the label order, "negative"
when it disagrees, and "equal" when the scores (rounded to 6 decimals) tie.
PNR is the number of positive pairs divided by the number of negative pairs.
"""

import itertools
import os
import sys

# Print a progress line every this many input lines (0 disables progress).
PROGRESS_EVERY = 500000


def count_pairs(scores, labels):
    """Classify every pair of items in one query that has different labels.

    Args:
        scores: sequence of numeric scores, one per item.
        labels: sequence of integer labels, parallel to ``scores``.

    Returns:
        A ``(pos, neg, equal)`` tuple: pairs whose score order agrees with
        the label order, disagrees with it, and pairs with tied scores.
        Pairs with equal labels are ignored.
    """
    pos = neg = equal = 0
    for (score_a, label_a), (score_b, label_b) in itertools.combinations(
            zip(scores, labels), 2):
        if label_a == label_b:
            continue
        # Same sign of score difference and label difference => concordant.
        product = (score_a - score_b) * (label_a - label_b)
        if product > 0:
            pos += 1
        elif product < 0:
            neg += 1
        else:
            equal += 1
    return pos, neg, equal


def _accumulate(stats, scores, labels):
    """Add the pair counts of one query group into ``stats`` in place."""
    pos, neg, equal = count_pairs(scores, labels)
    stats["pos_num"] += pos
    stats["neg_num"] += neg
    stats["equal_num"] += equal
    stats["pair_num"] += pos + neg + equal


def compute_pnr_counts(lines, progress_every=PROGRESS_EVERY):
    """Aggregate pair statistics over an iterable of input lines.

    Args:
        lines: iterable of ``"query\\tscore\\tlabel"`` lines (e.g. an open file).
        progress_every: print ``cnt:<n>`` every this many lines; 0 disables.

    Returns:
        A dict with integer keys ``pos_num``, ``neg_num``, ``equal_num``,
        ``pair_num`` and ``query_num``.

    Raises:
        ValueError: if a 3-column line has a non-numeric score or label.
    """
    stats = {"pos_num": 0, "neg_num": 0, "equal_num": 0,
             "pair_num": 0, "query_num": 0}
    scores, labels = [], []
    # None can never equal a parsed query id, so every id (even "-1")
    # starts a new group on first sight.
    last_query = None

    for cnt, line in enumerate(lines, 1):
        if progress_every and cnt % progress_every == 0:
            print("cnt:{}".format(cnt))
        cols = line.strip().split("\t")
        if len(cols) != 3:
            continue

        cur_query = cols[0]
        if cur_query != last_query:
            stats["query_num"] += 1
            _accumulate(stats, scores, labels)
            scores, labels = [], []
            last_query = cur_query

        # Rounding makes near-identical scores count as ties.
        scores.append(round(float(cols[1]), 6))
        labels.append(int(cols[2]))

    # Flush the final query group.
    _accumulate(stats, scores, labels)
    return stats


def format_report(stats):
    """Render the statistics from ``compute_pnr_counts`` as output lines.

    The PNR line is omitted when there are no negative pairs (division by
    zero), and the concordance-rate line is omitted when there are no
    positive or negative pairs at all.
    """
    pos_num = stats["pos_num"]
    neg_num = stats["neg_num"]
    equal_num = stats["equal_num"]
    report = []
    if neg_num > 0:
        report.append("pnr:{}".format(1.0 * pos_num / neg_num))
    report.append("query_num:{}".format(stats["query_num"]))
    # Both numbers should match; printed side by side as a sanity check.
    report.append("pair_num:{} , {}".format(pos_num + neg_num + equal_num,
                                            stats["pair_num"]))
    report.append("equal_num:{}".format(equal_num))
    if pos_num + neg_num > 0:
        # "正序率" = correct-order (concordance) rate, ignoring tied pairs.
        report.append("正序率: {}".format(1.0 * pos_num / (pos_num + neg_num)))
    report.append("pos_num: {} , neg_num: {}".format(pos_num, neg_num))
    return report


def main(argv=None):
    """Command-line entry point: ``python <script> input``."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print("usage:python {} input".format(argv[0]))
        sys.exit(-1)

    with open(argv[1]) as fin:
        stats = compute_pnr_counts(fin)

    for report_line in format_report(stats):
        print(report_line)


if __name__ == "__main__":
    main()