未验证 提交 a7d57d7a 编写于 作者: S shiyu22 提交者: GitHub

add Tanimoto ground truth (#1138)

* add milvus ground truth

* add milvus groundtruth

* [skip ci] add milvus ground truth

* [skip ci]add tanimoto ground truth
上级 003b0cc2
...@@ -5,28 +5,30 @@ ...@@ -5,28 +5,30 @@
### Parameter Description: ### Parameter Description:
| parameter | description | default setting | | parameter | description | default setting |
| ------------------ | --------------------------------------------- | -------------------- | | ------------------ | --------------------------------------------- | ----------------------- |
| GET_VEC | whether to save feature vectors | False | | PROCESS_NUM | number of processes | 12 |
| PROCESS_NUM | number of processes | 5 | | GET_VEC | whether to save feature vectors | False |
| IP | whether metric_type is IP | True | | CSV | whether the query vector file format is csv | False |
| L2 | whether metric_type is L2 | False | | UINT8 | whether the query vector data format is uint8 | False |
| CSV | whether the query vector file format is csv | False | | BASE_FOLDER_NAME | path to the source vector dataset | '/data/milvus/base' |
| UINT8 | whether the query vector data format is uint8 | False | | NQ_FOLDER_NAME | path to the query vector dataset | '/data/milvus/query' |
| BASE_FOLDER_NAME | path to the source vector dataset | '/data/milvus/base' | | GT_ALL_FOLDER_NAME | intermediate filename | 'ground_truth_all' |
| NQ_FOLDER_NAME | path to the query vector dataset | '/data/milvus/query' | | GT_FOLDER_NAME | path saved the ground truth results | 'ground_truth' |
| GT_ALL_FOLDER_NAME | intermediate filename | 'ground_truth_all' | | LOC_FILE_NAME | file saved the ground truth's location info | 'ground_truth.txt' |
| GT_FOLDER_NAME | path saved the ground truth results | 'ground_truth' | | FLOC_FILE_NAME | file saved the ground truth's filenames info | 'file_ground_truth.txt' |
| LOC_FILE_NAME | file saved the ground truth's location info | 'location.txt' | | VEC_FILE_NAME | file saved the ground truth's feature vectors | 'vectors.npy' |
| FLOC_FILE_NAME | file saved the ground truth's filenames info | 'file_location.txt' |
| VEC_FILE_NAME | file saved the ground truth's feature vectors | 'vectors.npy' |
### Usage: ### Usage:
```bash ```bash
$ python3 milvus_ground_truth.py [-q <nq_num>] -k <topk_num> -l $ python3 milvus_ground_truth.py [-q <nq_num>] -k <topk_num> -m <metric type> -l
# -q or --nq specifies the number of vectors taken from the query vector set. This parameter is optional; without it, all the data in the query set is used. # -q or --nq specifies the number of vectors taken from the query vector set. This parameter is optional; without it, all the data in the query set is used.
# -k or --topk specifies how many of the most similar vectors (top k) to compute. # -k or --topk specifies how many of the most similar vectors (top k) to compute.
# -l means generate the ground truth results, which are saved in GT_FOLDER_NAME. In this path, LOC_FILE_NAME stores the ground truth's location info, such as "8002005210": the first ‘8’ is meaningless, the 2nd-4th digits give the position of the result file in the folder, and the 5th-10th digits give the position of the result vector in the result file. The result filename and vector location are saved in FLOC_FILE_NAME, such as "binary_128d_00000.npy 81759", and the result vectors are saved in VEC_FILE_NAME.
# -m or --metric specifies how vector distances are compared in Milvus, such as IP/L2/Tan.
# -l means generate the ground truth results, which are saved in GT_FOLDER_NAME. In this path, LOC_FILE_NAME stores the ground truth's results info, such as "8002005210": the first ‘8’ is meaningless, the 2nd-4th digits give the position of the result file in the folder, and the 5th-10th digits give the position of the result vector in the result file. The result filename and vector location are saved in FLOC_FILE_NAME, such as "binary_128d_00000.npy 81759", and the result vectors are saved in VEC_FILE_NAME.
``` ```
\ No newline at end of file
...@@ -2,14 +2,13 @@ import getopt ...@@ -2,14 +2,13 @@ import getopt
import os import os
import sys import sys
import time import time
from multiprocessing import Process from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
import numpy as np import numpy as np
import math
PROCESS_NUM = 12
GET_VEC = False GET_VEC = False
PROCESS_NUM = 5
IP = True
L2 = False
CSV = False CSV = False
UINT8 = False UINT8 = False
...@@ -18,31 +17,24 @@ NQ_FOLDER_NAME = '/data/milvus/query' ...@@ -18,31 +17,24 @@ NQ_FOLDER_NAME = '/data/milvus/query'
GT_ALL_FOLDER_NAME = 'ground_truth_all' GT_ALL_FOLDER_NAME = 'ground_truth_all'
GT_FOLDER_NAME = 'ground_truth' GT_FOLDER_NAME = 'ground_truth'
LOC_FILE_NAME = 'location.txt' LOC_FILE_NAME = 'ground_truth.txt'
FLOC_FILE_NAME = 'file_location.txt' FLOC_FILE_NAME = 'file_ground_truth.txt'
VEC_FILE_NAME = 'vectors.npy' VEC_FILE_NAME = 'vectors.npy'
# get vectors of the files # get vectors of the files
def load_nq_vec(nq): def load_query_vec(nq, vectors=[], length=0):
vectors = []
length = 0
filenames = os.listdir(NQ_FOLDER_NAME) filenames = os.listdir(NQ_FOLDER_NAME)
filenames.sort() filenames.sort()
if nq == 0:
for filename in filenames:
vectors += load_vec_list(NQ_FOLDER_NAME + '/' + filename)
return vectors
for filename in filenames: for filename in filenames:
vec_list = load_vec_list(NQ_FOLDER_NAME + '/' + filename) vec_list = load_vec_list(NQ_FOLDER_NAME + '/' + filename)
length += len(vec_list) length += len(vec_list)
if length > nq: if nq!=0 and length>nq :
num = nq % len(vec_list) num = nq % len(vec_list)
# vec_list = load_vec_list(NQ_FOLDER_NAME + '/' + filename, num) vectors += vec_list[0:num]
vec_list = vec_list[0:num] break
vectors += vec_list vectors += vec_list
if len(vectors) == nq: return vectors
return vectors
# load vectors from file_name and num means nq's number # load vectors from file_name and num means nq's number
...@@ -55,12 +47,20 @@ def load_vec_list(file_name): ...@@ -55,12 +47,20 @@ def load_vec_list(file_name):
data = np.load(file_name) data = np.load(file_name)
if UINT8: if UINT8:
data = (data + 0.5) / 255 data = (data + 0.5) / 255
vec_list = [] vec_list = data.tolist()
for i in range(len(data)):
vec_list.append(data[i].tolist())
return vec_list return vec_list
def hex_to_bin(fp):
    """Expand a hexadecimal fingerprint string into a list of 0/1 ints.

    Every hex digit contributes exactly four bits, so the result always has
    ``len(fp) * 4`` entries; leading zero bits are preserved.

    Args:
        fp: Hexadecimal string (parsed with ``int(fp, 16)``).

    Returns:
        List of ints (each 0 or 1), most significant bit first. An empty
        string yields an empty list.

    Raises:
        ValueError: If ``fp`` is non-empty and not valid hexadecimal.
    """
    if not fp:
        # int('', 16) would raise; an empty fingerprint simply has no bits.
        return []
    width = len(fp) * 4
    # zfill restores the leading zeros that the binary formatting drops.
    bits = format(int(fp, 16), 'b').zfill(width)
    return [int(bit) for bit in bits]
def calEuclideanDistance(vec1, vec2): def calEuclideanDistance(vec1, vec2):
vec1 = np.array(vec1) vec1 = np.array(vec1)
vec2 = np.array(vec2) vec2 = np.array(vec2)
...@@ -75,6 +75,18 @@ def calInnerDistance(vec1, vec2): ...@@ -75,6 +75,18 @@ def calInnerDistance(vec1, vec2):
return dist return dist
def calTanimoto(vec1, vec2):
    """Return the Tanimoto similarity of two hexadecimal fingerprints.

    Each argument is expanded to a 0/1 bit list with ``hex_to_bin``. With
    ``nc`` the number of bits set in both fingerprints and ``n1``/``n2`` the
    number of bits set in each, the result is ``nc / (n1 + n2 - nc)``.

    Args:
        vec1: First fingerprint, in the form accepted by ``hex_to_bin``.
        vec2: Second fingerprint, in the form accepted by ``hex_to_bin``.

    Returns:
        The similarity as a float (larger means more similar).
    """
    bits1 = hex_to_bin(vec1)
    bits2 = hex_to_bin(vec2)
    nc = float(np.inner(bits1, bits2))
    n1 = float(np.sum(bits1))
    n2 = float(np.sum(bits2))
    union = n1 + n2 - nc
    if union == 0:
        # Both fingerprints have no bits set, so the ratio is 0/0. Treat two
        # identical (empty) bit sets as fully similar instead of raising
        # ZeroDivisionError.
        # NOTE(review): confirm this matches the convention Milvus uses.
        return 1.0
    return nc / union
def get_ground_truth_l2(topk, idx, vct_nq): def get_ground_truth_l2(topk, idx, vct_nq):
filenames = os.listdir(BASE_FOLDER_NAME) filenames = os.listdir(BASE_FOLDER_NAME)
filenames.sort() filenames.sort()
...@@ -85,7 +97,7 @@ def get_ground_truth_l2(topk, idx, vct_nq): ...@@ -85,7 +97,7 @@ def get_ground_truth_l2(topk, idx, vct_nq):
for j in range(len(vec_list)): for j in range(len(vec_list)):
dist = calEuclideanDistance(vct_nq, vec_list[j]) dist = calEuclideanDistance(vct_nq, vec_list[j])
num_j = "%01d%03d%06d" % (8, k, j) num_j = "%01d%03d%06d" % (8, k, j)
if j < topk and k == 0: if k==0 and j<topk :
no_dist[num_j] = dist no_dist[num_j] = dist
else: else:
# sorted by values # sorted by values
...@@ -94,8 +106,9 @@ def get_ground_truth_l2(topk, idx, vct_nq): ...@@ -94,8 +106,9 @@ def get_ground_truth_l2(topk, idx, vct_nq):
if dist < max_value: if dist < max_value:
m = no_dist.pop(max_key) m = no_dist.pop(max_key)
no_dist[num_j] = dist no_dist[num_j] = dist
k = k + 1 k += 1
no_dist = sorted(no_dist.items(), key=lambda x: x[1]) no_dist = sorted(no_dist.items(), key=lambda x: x[1])
print(no_dist)
save_gt_file(no_dist, idx) save_gt_file(no_dist, idx)
...@@ -108,8 +121,8 @@ def get_ground_truth_ip(topk, idx, vct_nq): ...@@ -108,8 +121,8 @@ def get_ground_truth_ip(topk, idx, vct_nq):
vec_list = load_vec_list(BASE_FOLDER_NAME + '/' + filename) vec_list = load_vec_list(BASE_FOLDER_NAME + '/' + filename)
for j in range(len(vec_list)): for j in range(len(vec_list)):
dist = calInnerDistance(vct_nq, vec_list[j]) dist = calInnerDistance(vct_nq, vec_list[j])
num_j = "%01d%03d%06d" % (8, k, j) num_j = "%03d%06d" % (k, j)
if j < topk and k == 0: if k==0 and j<topk :
no_dist[num_j] = dist no_dist[num_j] = dist
else: else:
min_key = min(no_dist, key=no_dist.get) min_key = min(no_dist, key=no_dist.get)
...@@ -117,55 +130,70 @@ def get_ground_truth_ip(topk, idx, vct_nq): ...@@ -117,55 +130,70 @@ def get_ground_truth_ip(topk, idx, vct_nq):
if dist > min_value: if dist > min_value:
m = no_dist.pop(min_key) m = no_dist.pop(min_key)
no_dist[num_j] = dist no_dist[num_j] = dist
k = k + 1 k += 1
no_dist = sorted(no_dist.items(), key=lambda x: x[1], reverse=True) no_dist = sorted(no_dist.items(), key=lambda x: x[1], reverse=True)
print(no_dist)
save_gt_file(no_dist, idx)
def get_ground_truth_tanimoto(topk, idx, vec_nq):
    """Compute and save the top-k Tanimoto ground truth for one query vector.

    Every vector in every file under BASE_FOLDER_NAME (files taken in sorted
    name order) is scored against ``vec_nq`` with ``calTanimoto``. The
    ``topk`` highest-scoring vectors are kept, ordered from most to least
    similar, and handed to ``save_gt_file``.

    Args:
        topk: Number of most similar base vectors to keep.
        idx: Index of the query vector; passed through to ``save_gt_file``.
        vec_nq: The query fingerprint, in the form ``calTanimoto`` accepts.
    """
    # Local import so this function does not depend on the file's import block.
    import heapq

    filenames = os.listdir(BASE_FOLDER_NAME)  # get the whole file names
    filenames.sort()

    def scored_candidates():
        # Lazily yield (location key, similarity) for every base vector.
        for file_idx, filename in enumerate(filenames):
            path = BASE_FOLDER_NAME + '/' + filename
            vec_list = load_vec_list(path)
            print(path, len(vec_list))
            for vec_idx, base_vec in enumerate(vec_list):
                # Same 10-character key as get_ground_truth_l2: a leading '8'
                # placeholder, 3 digits of file index, 6 digits of vector
                # index. get_file_loc_txt slices [1:4] and [4:10], so the
                # placeholder digit must be present.
                location = "%01d%03d%06d" % (8, file_idx, vec_idx)
                yield location, calTanimoto(vec_nq, base_vec)

    # nlargest keeps the best topk over *all* files (not only those seeded
    # from the first file) and returns them sorted by similarity, descending.
    no_dist = heapq.nlargest(topk, scored_candidates(), key=lambda item: item[1])
    print(no_dist)
    save_gt_file(no_dist, idx)
def save_gt_file(no_dist, idx): def save_gt_file(no_dist, idx):
s = "%05d" % idx filename = "%05d" % idx + 'results.txt'
idx_fname = GT_ALL_FOLDER_NAME + '/' + s + '_idx.txt' with open(GT_ALL_FOLDER_NAME+'/'+filename, 'w') as f:
dis_fname = GT_ALL_FOLDER_NAME + '/' + s + '_dis.txt'
with open(idx_fname, 'w') as f:
for re in no_dist: for re in no_dist:
f.write(str(re[0]) + '\n') f.write(str(re[0]) + ' ' + str(re[1]) + '\n')
f.write('\n')
with open(dis_fname, 'w') as f:
for re in no_dist:
f.write(str(re[1]) + '\n')
f.write('\n')
def get_loc_txt(file): def get_loc_txt(file):
filenames = os.listdir(GT_ALL_FOLDER_NAME) filenames = os.listdir(GT_ALL_FOLDER_NAME)
filenames.sort() filenames.sort()
write_file = open(GT_FOLDER_NAME + '/' + file, 'w+') write_file = open(GT_FOLDER_NAME + '/' + file, 'w+')
for f in filenames: for f in filenames:
if f.endswith('_idx.txt'): for line in open(GT_ALL_FOLDER_NAME+'/'+f, 'r'):
f = GT_ALL_FOLDER_NAME + '/' + f write_file.write(line)
for line in open(f, 'r'): write_file.write('\n')
write_file.write(line)
def get_file_loc_txt(gt_file, fnames_file): def get_file_loc_txt(gt_file, fnames_file):
gt_file = GT_FOLDER_NAME + '/' + gt_file
fnames_file = GT_FOLDER_NAME + '/' + fnames_file
filenames = os.listdir(BASE_FOLDER_NAME) filenames = os.listdir(BASE_FOLDER_NAME)
filenames.sort() filenames.sort()
with open(gt_file, 'r') as gt_f: with open(GT_FOLDER_NAME+'/'+gt_file, 'r') as gt_f:
with open(fnames_file, 'w') as fnames_f: with open(GT_FOLDER_NAME+'/'+fnames_file, 'w') as fnames_f:
for line in gt_f: for line in gt_f:
if line != '\n': if line != '\n':
line = line.strip() line = line.split()[0]
loca = int(line[1:4]) loca = int(line[1:4])
offset = int(line[4:10]) offset = int(line[4:10])
fnames_f.write(filenames[loca] + ' ' + str(offset + 1) + '\n') fnames_f.write(filenames[loca] + ' ' + str(offset + 1) + '\n')
else: else:
fnames_f.write('\n') fnames_f.write(line)
def load_gt_file_out(): def load_gt_file_out():
file_name = GT_FOLDER_NAME + '/' +FLOC_FILE_NAME file_name = GT_FOLDER_NAME + '/' + FLOC_FILE_NAME
base_filename = [] base_filename = []
num = [] num = []
with open(file_name, 'r') as f: with open(file_name, 'r') as f:
...@@ -177,70 +205,36 @@ def load_gt_file_out(): ...@@ -177,70 +205,36 @@ def load_gt_file_out():
return base_filename, num return base_filename, num
def ground_truth_process(nq, topk): def ground_truth_process(metric,nq_list, topk, num):
try: thread_num = len(nq_list)
os.mkdir(GT_ALL_FOLDER_NAME) with ProcessPoolExecutor(thread_num) as executor:
except: for i in range(thread_num):
print('there already exits folder named ' + GT_ALL_FOLDER_NAME + ', please delete it first.') # print("Process:",num+i)
else: if metric == 'L2':
vectors = load_nq_vec(nq) executor.submit(get_ground_truth_l2, topk, num+i, nq_list[i])
print("query list:", len(vectors)) elif metric == 'IP':
processes = [] executor.submit(get_ground_truth_ip, topk, num+i, nq_list[i])
process_num = PROCESS_NUM elif metric == 'Tan':
nq = len(vectors) executor.submit(get_ground_truth_tanimoto, topk, num+i, nq_list[i])
loops = nq // process_num get_loc_txt(LOC_FILE_NAME)
rest = nq % process_num get_file_loc_txt(LOC_FILE_NAME, FLOC_FILE_NAME)
if rest != 0: if GET_VEC:
loops += 1 vec = []
time_start = time.time() file, num = load_gt_file_out()
for loop in range(loops): for i in range(len(file)):
time1_start = time.time() n = int(num[i]) - 1
base = loop * process_num vectors = load_vec_list(BASE_FOLDER_NAME + '/' + file[i])
if rest != 0 and loop == loops - 1: vec.append(vectors[n])
process_num = rest print("saved len of vec:", len(vec))
print('base:', loop) np.save(GT_FOLDER_NAME + '/' + VEC_FILE_NAME, vec)
for i in range(process_num):
print('nq_index:', base + i)
if L2:
if base + i == 0:
print("get ground truth by L2.")
process = Process(target=get_ground_truth_l2,
args=(topk, base + i, vectors[base + i]))
elif IP:
if base + i == 0:
print("get ground truth by IP.")
process = Process(target=get_ground_truth_ip,
args=(topk, base + i, vectors[base + i]))
processes.append(process)
process.start()
for p in processes:
p.join()
time1_end = time.time()
print("base", loop, "time_cost = ", round(time1_end - time1_start, 4))
if not os.path.exists(GT_FOLDER_NAME):
os.mkdir(GT_FOLDER_NAME)
get_loc_txt(LOC_FILE_NAME)
get_file_loc_txt(LOC_FILE_NAME, FLOC_FILE_NAME)
if GET_VEC:
vec = []
file, num = load_gt_file_out()
for i in range(len(file)):
n = int(num[i]) - 1
vectors = load_vec_list(BASE_FOLDER_NAME + '/' + file[i])
vec.append(vectors[n])
print("saved len of vec:", len(vec))
np.save(GT_FOLDER_NAME + '/' + VEC_FILE_NAME, vec)
time_end = time.time()
time_cost = time_end - time_start
print("total_time = ", round(time_cost, 4), "\nGet the ground truth successfully!")
def main(): def main():
try: try:
opts, args = getopt.getopt( opts, args = getopt.getopt(
sys.argv[1:], sys.argv[1:],
"hlq:k:", "hlq:k:m:",
["help", "nq=", "topk="], ["help", "nq=", "topk=", "metric="],
) )
except getopt.GetoptError: except getopt.GetoptError:
print("Usage: test.py [-q <nq>] -k <topk> -s") print("Usage: test.py [-q <nq>] -k <topk> -s")
...@@ -254,9 +248,33 @@ def main(): ...@@ -254,9 +248,33 @@ def main():
nq = int(opt_value) nq = int(opt_value)
elif opt_name in ("-k", "--topk"): elif opt_name in ("-k", "--topk"):
topk = int(opt_value) topk = int(opt_value)
elif opt_name == "-l": elif opt_name in ("-m", "--metric"):
ground_truth_process(nq, topk) # test.py [-q <nq>] -k <topk> -l metric = opt_value
sys.exit() elif opt_name == "-l": # test.py [-q <nq>] -k <topk> -m -l
try:
os.mkdir(GT_ALL_FOLDER_NAME)
except:
print('there already exits folder named ' + GT_ALL_FOLDER_NAME + ', please delete it first.')
sys.exit()
if not os.path.exists(GT_FOLDER_NAME):
os.mkdir(GT_FOLDER_NAME)
print("metric type is",metric)
time_start = time.time()
query_vectors = load_query_vec(nq)
nq = len(query_vectors)
print("query list:", len(query_vectors))
num = math.ceil(nq/PROCESS_NUM)
for i in range(num):
print("start with round:",i+1)
if i==num-1:
ground_truth_process(metric, query_vectors[i*PROCESS_NUM:nq], topk, i*PROCESS_NUM)
else:
ground_truth_process(metric, query_vectors[i*PROCESS_NUM:i*PROCESS_NUM+PROCESS_NUM], topk, i*PROCESS_NUM)
time_end = time.time()
time_cost = time_end - time_start
print("total_time = ", round(time_cost, 4), "\nGet the ground truth successfully!")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册