未验证 提交 680c0c9e 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #761 from FredHuang16/develop_reg

add vector search
CXX=/usr/bin/g++-5
all : index
index.so : src/config.h src/graph.h src/data.h interface.cc
$(CXX) -shared -fPIC interface.cc -o index.so -std=c++11 -Ofast -march=native -g -flto -funroll-loops -DOMP -fopenmp
# 向量检索
## 简介
一些垂域识别任务(如车辆、商品等)需要识别的类别数较大,往往采用基于检索的方式,通过查询向量与底库向量进行快速的最近邻搜索,获得匹配的预测类别。向量检索模块提供基础的近似最近邻搜索算法,基于百度自研的Möbius算法,一种基于图的近似最近邻搜索算法,用于最大内积搜索 (MIPS)。 该模块提供python接口,支持numpy和 tensor类型向量,支持L2和Inner Product距离计算。
Mobius 算法细节详见论文 ([Möbius Transformation for Fast Inner Product Search on Graph](http://research.baidu.com/Public/uploads/5e189d36b5cf6.PDF), [Code](https://github.com/sunbelbd/mobius)
## 安装
若index.so不可用,在项目目录下运行以下命令生成新的index.so文件
make index.so
编译环境: g++ 5.4.0 , 9.3.0. 其他版本也可能工作。 请确保您的 C++ 编译器支持 C++11 标准。
## 快速使用
import numpy as np
from interface import Graph_Index
# 随机产生样本
index_vectors = np.random.rand(100000,128).astype(np.float32)
query_vector = np.random.rand(128).astype(np.float32)
index_docs = ["ID_"+str(i) for i in range(100000)]
# 初始化索引结构
indexer = Graph_Index(dist_type="IP") #支持"IP"和"L2"
indexer.build(gallery_vectors=index_vectors, gallery_docs=index_docs, pq_size=100, index_path='test')
# 查询
scores, docs = indexer.search(query=query_vector, return_k=10, search_budget=100)
print(scores)
print(docs)
# 保存与加载
indexer.dump(index_path="test")
indexer.load(index_path="test")
from .interface import Graph_Index
#MIT License
#
#Copyright (c) 2021 Mobius Authors
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#from https://github.com/sunbelbd/mobius/blob/e2d166547d61d791da8f06747a63b9cd38f02c71/main.cc
#include<stdio.h>
#include<string.h>
#include <iostream>
#include <fstream>
#include <queue>
#include <chrono>
#include <unordered_set>
#include <unordered_map>
#include<stdlib.h>
#include<memory>
#include<vector>
#include<functional>
#include"src/data.h"
#include"src/graph.h"
struct IndexContext{
void* graph;
void* data;
};
int topk = 0;
int display_topk = 1;
int build_idx_offset = 0;
int query_idx_offset = 0;
void flush_add_buffer(
std::vector<std::pair<idx_t,std::vector<std::pair<int,value_t>>>>& add_buffer,
GraphWrapper* graph){
#pragma omp parallel for
for(int i = 0;i < add_buffer.size();++i){
auto& idx = add_buffer[i].first;
auto& point = add_buffer[i].second;
graph->add_vertex_lock(idx,point);
}
add_buffer.clear();
}
extern "C"{
// for mobius IP index
void build_mobius_index(float* dense_mat,int row,int dim, int pq_size, double mobius_pow , const char* prefix){
std::unique_ptr<Data> data;
std::unique_ptr<Data> data_original;
std::unique_ptr<GraphWrapper> graph;
int topk = 0;
int display_topk = 1;
int build_idx_offset = 0;
int query_idx_offset = 0;
++row;
data = std::unique_ptr<Data>(new Data(row,dim));
graph = std::unique_ptr<GraphWrapper>(new FixedDegreeGraph<3>(data.get()));
graph->set_construct_pq_size(pq_size);
std::vector<std::pair<idx_t,std::vector<std::pair<int,value_t>>>> add_buffer;
((FixedDegreeGraph<3>*)graph.get())->get_data()->mobius_pow = mobius_pow;
data_original = std::unique_ptr<Data>(new Data(row,dim));
std::vector<std::pair<int,value_t>> dummy_mobius_point;
for(int i = 0;i < dim;++i)
dummy_mobius_point.push_back(std::make_pair(i,0));
//idx += build_idx_offset;
for(int i = 0;i < row - 1;++i){
std::vector<std::pair<int,value_t>> point;
point.reserve(dim);
for(int j = 0;j < dim;++j)
point.push_back(std::make_pair(j,dense_mat[i * dim + j]));
data_original->add(i,point);
data->add_mobius(i,point);
if(i < 1000){
graph->add_vertex(i,point);
}else{
add_buffer.push_back(std::make_pair(i,point));
}
if(add_buffer.size() >= 1000000)
flush_add_buffer(add_buffer,graph.get());
}
flush_add_buffer(add_buffer,graph.get());
graph->add_vertex(row - 1,dummy_mobius_point);
data.swap(data_original);
std::string str = std::string(prefix);
data->dump(str + ".data");
graph->dump(str + ".graph");
}
void load_mobius_index_prefix(int row,int dim,IndexContext* index_context,const char* prefix){
std::string str = std::string(prefix);
++row;
Data* data = new Data(row,dim);
GraphWrapper* graph = new FixedDegreeGraph<1>(data);
//idx += build_idx_offset;
data->load(str + ".data");
graph->load(str + ".graph");
((FixedDegreeGraph<1>*)graph)->search_start_point = row - 1;
((FixedDegreeGraph<1>*)graph)->ignore_startpoint = true;
index_context->graph = graph;
index_context->data = data;
}
void save_mobius_index_prefix(IndexContext* index_context,const char* prefix){
std::string str = std::string(prefix);
Data* data = (Data*)(index_context->data);
GraphWrapper* graph = (GraphWrapper*)(index_context->graph);
data->dump(str + ".data");
graph->dump(str + ".graph");
}
void search_mobius_index(float* dense_vec,int dim,int search_budget,int return_k, IndexContext* index_context,idx_t* ret_id,double* ret_score){
int topk = 0;
int display_topk = 1;
int build_idx_offset = 0;
int query_idx_offset = 0;
Data* data = reinterpret_cast<Data*>(index_context->data);
GraphWrapper* graph = reinterpret_cast<GraphWrapper*>(index_context->graph);
//auto flag = (data==NULL);
//std::cout<<flag<<std::endl;
std::vector<std::pair<int,value_t>> point;
point.reserve(dim);
for(int j = 0;j < dim;++j)
point.push_back(std::make_pair(j,dense_vec[j]));
std::vector<idx_t> topN;
std::vector<double> score;
graph->search_top_k_with_score(point,search_budget,topN,score);
for(int i = 0;i < topN.size() && i < return_k;++i){
ret_id[i] = topN[i];
ret_score[i] = score[i];
}
}
// For L2 index
void build_l2_index(float* dense_mat,int row,int dim, int pq_size, const char* prefix){
std::unique_ptr<Data> data;
std::unique_ptr<GraphWrapper> graph;
int topk = 0;
int display_topk = 1;
int build_idx_offset = 0;
int query_idx_offset = 0;
data = std::unique_ptr<Data>(new Data(row,dim));
graph = std::unique_ptr<GraphWrapper>(new FixedDegreeGraph<3>(data.get()));
graph->set_construct_pq_size(pq_size);
std::vector<std::pair<idx_t,std::vector<std::pair<int,value_t>>>> add_buffer;
for(int i = 0;i < row;++i){
std::vector<std::pair<int,value_t>> point;
point.reserve(dim);
for(int j = 0;j < dim;++j)
point.push_back(std::make_pair(j,dense_mat[i * dim + j]));
data->add(i,point);
if(i < 1000){
graph->add_vertex(i,point);
}else{
add_buffer.push_back(std::make_pair(i,point));
}
if(add_buffer.size() >= 1000000)
flush_add_buffer(add_buffer,graph.get());
}
flush_add_buffer(add_buffer,graph.get());
std::string str = std::string(prefix);
data->dump(str + ".data");
graph->dump(str + ".graph");
}
void load_l2_index_prefix(int row,int dim,IndexContext* index_context,const char* prefix){
std::string str = std::string(prefix);
Data* data = new Data(row,dim);
GraphWrapper* graph = new FixedDegreeGraph<3>(data);
//idx += build_idx_offset;
data->load(str + ".data");
graph->load(str + ".graph");
index_context->graph = graph;
index_context->data = data;
}
void save_l2_index_prefix(IndexContext* index_context,const char* prefix){
std::string str = std::string(prefix);
Data* data = (Data*)(index_context->data);
GraphWrapper* graph = (GraphWrapper*)(index_context->graph);
data->dump(str + ".data");
graph->dump(str + ".graph");
}
void search_l2_index(float* dense_vec,int dim,int search_budget,int return_k, IndexContext* index_context,idx_t* ret_id,double* ret_score){
int topk = 0;
int display_topk = 1;
int build_idx_offset = 0;
int query_idx_offset = 0;
Data* data = reinterpret_cast<Data*>(index_context->data);
GraphWrapper* graph = reinterpret_cast<GraphWrapper*>(index_context->graph);
std::vector<std::pair<int,value_t>> point;
point.reserve(dim);
for(int j = 0;j < dim;++j)
point.push_back(std::make_pair(j,dense_vec[j]));
std::vector<idx_t> topN;
std::vector<double> score;
graph->search_top_k_with_score(point,search_budget,topN,score);
for(int i = 0;i < topN.size() && i < return_k;++i){
// printf("%d: (%zu, %f)\n",i,topN[i],score[i]);
ret_id[i] = topN[i];
ret_score[i] = score[i];
}
}
void release_context(IndexContext* index_context){
delete (Data*)(index_context->data);
delete (GraphWrapper*)(index_context->graph);
}
} // extern "C"
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import paddle
import numpy.ctypeslib as ctl
import numpy as np
import os
import json
from ctypes import *
from numpy.ctypeslib import ndpointer
lib = ctypes.cdll.LoadLibrary("./index.so")
class IndexContext(Structure):
_fields_=[("graph",c_void_p),
("data",c_void_p)]
# for mobius IP index
build_mobius_index = lib.build_mobius_index
build_mobius_index.restype = None
build_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_double, ctypes.c_char_p]
search_mobius_index = lib.search_mobius_index
search_mobius_index.restype = None
search_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')]
load_mobius_index_prefix = lib.load_mobius_index_prefix
load_mobius_index_prefix.restype = None
load_mobius_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p]
save_mobius_index_prefix = lib.save_mobius_index_prefix
save_mobius_index_prefix.restype = None
save_mobius_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p]
# for L2 index
build_l2_index = lib.build_l2_index
build_l2_index.restype = None
build_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]
search_l2_index = lib.search_l2_index
search_l2_index.restype = None
search_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')]
load_l2_index_prefix = lib.load_l2_index_prefix
load_l2_index_prefix.restype = None
load_l2_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p]
save_l2_index_prefix = lib.save_l2_index_prefix
save_l2_index_prefix.restype = None
save_l2_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p]
release_context = lib.release_context
release_context.restype = None
release_context.argtypes = [POINTER(IndexContext)]
class Graph_Index(object):
"""
graph index
"""
def __init__(self, dist_type="IP"):
self.dim = 0
self.total_num = 0
self.dist_type = dist_type
self.mobius_pow = 2.0
self.index_context = IndexContext(0,0)
self.gallery_doc_dict = {}
self.with_attr = False
assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
def build(self, gallery_vectors, gallery_docs=[], pq_size=100, index_path='graph_index/'):
"""
build index
"""
if paddle.is_tensor(gallery_vectors):
gallery_vectors = gallery_vectors.numpy()
assert gallery_vectors.ndim == 2, "Input vector must be 2D ..."
self.total_num = gallery_vectors.shape[0]
self.dim = gallery_vectors.shape[1]
assert (len(gallery_docs) == self.total_num if len(gallery_docs)>0 else True)
print("training index -> num: {}, dim: {}, dist_type: {}".format(self.total_num, self.dim, self.dist_type))
if not os.path.exists(index_path):
os.makedirs(index_path)
if self.dist_type == "IP":
build_mobius_index(gallery_vectors,self.total_num,self.dim, pq_size, self.mobius_pow, create_string_buffer((index_path+"/index").encode('utf-8')))
load_mobius_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
else:
build_l2_index(gallery_vectors,self.total_num,self.dim, pq_size, create_string_buffer((index_path+"/index").encode('utf-8')))
load_l2_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
self.gallery_doc_dict = {}
if len(gallery_docs) > 0:
self.with_attr = True
for i in range(gallery_vectors.shape[0]):
self.gallery_doc_dict[str(i)] = gallery_docs[i]
self.gallery_doc_dict["total_num"] = self.total_num
self.gallery_doc_dict["dim"] = self.dim
self.gallery_doc_dict["dist_type"] = self.dist_type
self.gallery_doc_dict["with_attr"] = self.with_attr
with open(index_path + "/info.json", "w") as f:
json.dump(self.gallery_doc_dict, f)
print("finished creating index ...")
def search(self, query, return_k=10, search_budget=100):
"""
search
"""
ret_id = np.zeros(return_k, dtype=np.uint64)
ret_score = np.zeros(return_k, dtype=np.float64)
if paddle.is_tensor(query):
query = query.numpy()
if self.dist_type == "IP":
search_mobius_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score)
else:
search_l2_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score)
ret_id = ret_id.tolist()
ret_doc = []
if self.with_attr:
for i in range(return_k):
ret_doc.append(self.gallery_doc_dict[str(ret_id[i])])
return ret_score, ret_doc
else:
return ret_score, ret_id
def dump(self, index_path):
if not os.path.exists(index_path):
os.makedirs(index_path)
if self.dist_type == "IP":
save_mobius_index_prefix(ctypes.byref(self.index_context),create_string_buffer((index_path+"/index").encode('utf-8')))
else:
save_l2_index_prefix(ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
with open(index_path + "/info.json", "w") as f:
json.dump(self.gallery_doc_dict, f)
def load(self, index_path):
self.gallery_doc_dict = {}
with open(index_path + "/info.json", "r") as f:
self.gallery_doc_dict = json.load(f)
self.total_num = self.gallery_doc_dict["total_num"]
self.dim = self.gallery_doc_dict["dim"]
self.dist_type = self.gallery_doc_dict["dist_type"]
self.with_attr = self.gallery_doc_dict["with_attr"]
if self.dist_type == "IP":
load_mobius_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
else:
load_l2_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
# MIT License
#
#Copyright (c) 2021 Mobius Authors
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#from https://github.com/sunbelbd/mobius/blob/e2d166547d61d791da8f06747a63b9cd38f02c71/config.h
#pragma once
typedef float value_t;
//typedef double dist_t;
typedef float dist_t;
typedef size_t idx_t;
typedef int UINT;
#define ACC_BATCH_SIZE 4096
#define FIXED_DEGREE 31
#define FIXED_DEGREE_SHIFT 5
//for construction
#define SEARCH_DEGREE 15
#define CONSTRUCT_SEARCH_BUDGET 150
# MIT License
#
#Copyright (c) 2021 Mobius Authors
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#from https://github.com/sunbelbd/mobius/blob/e2d166547d61d791da8f06747a63b9cd38f02c71/data.h
#pragma once
#include<memory>
#include<vector>
#include<math.h>
#include"config.h"
#define ZERO_EPS 1e-10
#define _SCALE_WORLD_DENSE_DATA
#ifdef _SCALE_WORLD_DENSE_DATA
//dense data
class Data{
private:
std::unique_ptr<value_t[]> data;
size_t num;
size_t curr_num = 0;
int dim;
public:
value_t mobius_pow = 2;
value_t max_ip_norm = 1;
value_t max_ip_norm2 = 1;
Data(size_t num, int dim) : num(num),dim(dim){
data = std::unique_ptr<value_t[]>(new value_t[num * dim]);
memset(data.get(),0,sizeof(value_t) * num * dim);
}
value_t* get(idx_t idx) const{
return data.get() + idx * dim;
}
template<class T>
dist_t ipwrap_l2_query_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
dist_t normu = 0;
for(int i = 0;i < dim;++i){
auto diff = (*(pa + i) / max_ip_norm) - v[i];
ret += diff * diff;
normu += (*(pa + i)) * (*(pa + i));
}
ret += 1 - normu / max_ip_norm2;
return ret;
}
template<class T>
dist_t ipwrap_l2_build_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
dist_t normu = 0;
dist_t normv = 0;
for(int i = 0;i < dim;++i){
auto diff = *(pa + i) - v[i];
ret += diff * diff;
normu += (*(pa + i)) * (*(pa + i));
normv += v[i] * v[i];
}
dist_t wrap_termu = sqrt(1 - normu / max_ip_norm2);
dist_t wrap_termv = sqrt(1 - normv / max_ip_norm2);
dist_t diff_wrap = wrap_termu - wrap_termv;
ret = ret / max_ip_norm2 + diff_wrap * diff_wrap;
return ret;
}
template<class T>
dist_t l2_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
for(int i = 0;i < dim;++i){
auto diff = *(pa + i) - v[i];
ret += diff * diff;
}
return ret;
}
template<class T>
dist_t negative_inner_prod_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
for(int i = 0;i < dim;++i){
ret -= (*(pa + i)) * v[i];
}
return ret;
}
template<class T>
dist_t negative_cosine_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
value_t lena = 0,lenv = 0;
for(int i = 0;i < dim;++i){
ret += (*(pa + i)) * v[i];
lena += (*(pa + i)) * (*(pa + i));
lenv += v[i] * v[i];
}
int sign = ret < 0 ? 1 : -1;
// return sign * (ret * ret / lena);// / lenv);
return sign * (ret * ret / lena / lenv);
}
template<class T>
dist_t mobius_l2_distance(idx_t a,T& v) const{
auto pa = get(a);
dist_t ret = 0;
value_t lena = 0,lenv = 0;
for(int i = 0;i < dim;++i){
lena += (*(pa + i)) * (*(pa + i));
lenv += v[i] * v[i];
}
value_t modifier_a = pow(lena,0.5 * mobius_pow);
value_t modifier_v = pow(lenv,0.5 * mobius_pow);
if(fabs(modifier_a) < ZERO_EPS)
modifier_a = 1;
if(fabs(modifier_v) < ZERO_EPS)
modifier_v = 1;
for(int i = 0;i < dim;++i){
value_t tmp = (*(pa + i)) / modifier_a - v[i] / modifier_v;
ret += tmp * tmp;
}
return ret;
}
template<class T>
dist_t real_nn(T& v) const{
dist_t minn = 1e100;
for(size_t i = 0;i < curr_num;++i){
auto res = l2_distance(i,v);
if(res < minn){
minn = res;
}
}
return minn;
}
std::vector<value_t> organize_point_mobius(const std::vector<std::pair<int,value_t>>& v){
std::vector<value_t> ret(dim,0);
value_t lena = 0;
for(const auto& p : v){
// ret[p.first] = p.second;
lena += p.second * p.second;
}
value_t modifier_a = pow(lena,0.5 * mobius_pow);
if(fabs(modifier_a) < ZERO_EPS)
modifier_a = 1;
for(const auto& p : v){
ret[p.first] = p.second / modifier_a;
}
return std::move(ret);
}
std::vector<value_t> organize_point(const std::vector<std::pair<int,value_t>>& v){
std::vector<value_t> ret(dim,0);
for(const auto& p : v){
if(p.first >= dim)
printf("error %d %d\n",p.first,dim);
ret[p.first] = p.second;
}
return std::move(ret);
}
value_t vec_sum2(const std::vector<std::pair<int,value_t>>& v){
value_t ret = 0;
for(const auto& p : v){
if(p.first >= dim)
printf("error %d %d\n",p.first,dim);
ret += p.second * p.second;
}
return std::move(ret);
}
void add(idx_t idx, std::vector<std::pair<int,value_t>>& value){
//printf("adding %zu\n",idx);
//for(auto p : value)
// printf("%zu %d %f\n",idx,p.first,p.second);
curr_num = std::max(curr_num,idx);
auto p = get(idx);
for(const auto& v : value)
*(p + v.first) = v.second;
}
void add_mobius(idx_t idx, std::vector<std::pair<int,value_t>>& value){
//printf("adding %zu\n",idx);
//for(auto p : value)
// printf("%zu %d %f\n",idx,p.first,p.second);
curr_num = std::max(curr_num,idx);
auto p = get(idx);
value_t lena = 0;
for(const auto& v : value){
*(p + v.first) = v.second;
lena += v.second * v.second;
}
value_t modifier_a = pow(lena,0.5 * mobius_pow);
if(fabs(modifier_a) < ZERO_EPS)
modifier_a = 1;
for(const auto& v : value){
*(p + v.first) = v.second / modifier_a;
}
}
inline size_t max_vertices(){
return num;
}
inline size_t curr_vertices(){
return curr_num;
}
void print(){
for(int i = 0;i < num && i < 10;++i)
printf("%f ",*(data.get() + i));
printf("\n");
}
int get_dim(){
return dim;
}
void dump(std::string path = "bfsg.data"){
FILE* fp = fopen(path.c_str(),"wb");
fwrite(data.get(),sizeof(value_t) * num * dim,1,fp);
fclose(fp);
}
void load(std::string path = "bfsg.data"){
curr_num = num;
FILE* fp = fopen(path.c_str(),"rb");
auto cnt = fread(data.get(),sizeof(value_t) * num * dim,1,fp);
fclose(fp);
}
};
template<>
dist_t Data::ipwrap_l2_build_distance(idx_t a,idx_t& b) const{
auto pa = get(a);
auto pb = get(b);
dist_t ret = 0;
dist_t normu = 0;
dist_t normv = 0;
for(int i = 0;i < dim;++i){
auto diff = *(pa + i) - *(pb + i);
ret += diff * diff;
normu += (*(pa + i)) * (*(pa + i));
normv += (*(pb + i)) * (*(pb + i));
}
dist_t wrap_termu = sqrt(1 - normu / max_ip_norm2);
dist_t wrap_termv = sqrt(1 - normv / max_ip_norm2);
dist_t diff_wrap = wrap_termu - wrap_termv;
ret = ret / max_ip_norm2 + diff_wrap * diff_wrap;
return ret;
}
template<>
dist_t Data::ipwrap_l2_query_distance(idx_t a,idx_t& b) const{
auto pa = get(a);
auto pb = get(b);
dist_t ret = 0;
dist_t normu = 0;
for(int i = 0;i < dim;++i){
auto diff = (*(pa + i) / max_ip_norm) - *(pb + i);
ret += diff * diff;
normu += (*(pa + i)) * (*(pa + i));
}
ret += 1 - normu / max_ip_norm2;
return ret;
}
template<>
dist_t Data::l2_distance(idx_t a,idx_t& b) const{
auto pa = get(a),
pb = get(b);
dist_t ret = 0;
for(int i = 0;i < dim;++i){
auto diff = *(pa + i) - *(pb + i);
ret += diff * diff;
}
return ret;
}
template<>
dist_t Data::negative_inner_prod_distance(idx_t a,idx_t& b) const{
auto pa = get(a),
pb = get(b);
dist_t ret = 0;
for(int i = 0;i < dim;++i){
ret -= (*(pa + i)) * (*(pb + i));
}
return ret;
}
template<>
dist_t Data::negative_cosine_distance(idx_t a,idx_t& b) const{
auto pa = get(a),
pb = get(b);
dist_t ret = 0;
value_t lena = 0,lenv = 0;
for(int i = 0;i < dim;++i){
ret += (*(pa + i)) * (*(pb + i));
lena += (*(pa + i)) * (*(pa + i));
lenv += (*(pb + i)) * (*(pb + i));
}
int sign = ret < 0 ? 1 : -1;
// return sign * (ret * ret / lena);
return sign * (ret * ret / lena / lenv);
}
template<>
dist_t Data::mobius_l2_distance(idx_t a,idx_t& b) const{
auto pa = get(a),
pb = get(b);
dist_t ret = 0;
value_t lena = 0,lenv = 0;
for(int i = 0;i < dim;++i){
lena += (*(pa + i)) * (*(pa + i));
lenv += (*(pb + i)) * (*(pb + i));
}
value_t modifier_a = pow(lena,0.5 * mobius_pow);
value_t modifier_v = pow(lenv,0.5 * mobius_pow);
if(fabs(modifier_a) < ZERO_EPS)
modifier_a = 1;
if(fabs(modifier_v) < ZERO_EPS)
modifier_v = 1;
for(int i = 0;i < dim;++i){
value_t tmp = (*(pa + i)) / modifier_a - (*(pb + i)) / modifier_v;
ret += tmp * tmp;
}
return ret;
}
#else
//sparse data
class Data{
public:
//TODO
};
#endif
# MIT License
#
#Copyright (c) 2021 Mobius Authors
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#from https://github.com/sunbelbd/mobius/blob/e2d166547d61d791da8f06747a63b9cd38f02c71/graph.h
#pragma once
#include<vector>
#include<algorithm>
#include<queue>
#include<stdlib.h>
#include<random>
#include<unordered_set>
#include<mutex>
#include<time.h>
#include"config.h"
#include"data.h"
#ifdef OMP
#include<omp.h>
#endif
typedef unsigned int vl_type;
class VisitedList {
public:
vl_type curV;
vl_type *mass;
unsigned int numelements;
VisitedList(int numelements1) {
curV = 1;
numelements = numelements1;
mass = new vl_type[numelements];
memset(mass, 0, sizeof(vl_type) * numelements);
}
void reset() {
++curV;
if (curV == 0) {
curV = 1;
memset(mass, 0, sizeof(vl_type) * numelements);
}
};
~VisitedList() { delete mass; }
};
struct GraphMeasures{
int distance_cnt = 0;
};
class GraphWrapper{
public:
virtual void add_vertex(idx_t vertex_id,std::vector<std::pair<int,value_t>>& point) = 0;
virtual void add_vertex_lock(idx_t vertex_id,std::vector<std::pair<int,value_t>>& point) = 0;
virtual void search_top_k(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result) = 0;
virtual void search_top_k_with_score(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result,std::vector<double>& score){}
virtual void dump(std::string path = "bfsg.graph") = 0;
virtual void load(std::string path = "bfsg.graph") = 0;
virtual ~GraphWrapper(){}
virtual void set_construct_pq_size(int size){};
GraphMeasures measures;
};
template<const int dist_type>
class FixedDegreeGraph : public GraphWrapper{
private:
const int degree = SEARCH_DEGREE;
const int flexible_degree = FIXED_DEGREE;
const int vertex_offset_shift = FIXED_DEGREE_SHIFT;
std::vector<idx_t> edges;
std::vector<dist_t> edge_dist;
Data* data;
std::mt19937_64 rand_gen = std::mt19937_64(1234567);//std::random_device{}());
std::vector<std::mutex> edge_mutex;//do not push back on this vector, it will destroy the mutex
bool debug = false;
VisitedList* p_visited = NULL;
#ifdef OMP
std::vector<VisitedList*> visited_pool;
#endif
int construct_pq_size = CONSTRUCT_SEARCH_BUDGET;
void rank_and_switch_ordered(idx_t v_id,idx_t u_id){
//We assume the neighbors of v_ids in edges[offset] are sorted
//by the distance to v_id ascendingly when it is full
//NOTICE: before it is full, it is unsorted
auto curr_dist = pair_distance(v_id,u_id);
auto offset = ((size_t)v_id) << vertex_offset_shift;
int degree = edges[offset];
std::vector<idx_t> neighbor;
neighbor.reserve(degree + 1);
for(int i = 0;i < degree;++i)
neighbor.push_back(edges[offset + i + 1]);
neighbor.push_back(u_id);
neighbor = edge_selection_filter_neighbor(neighbor,v_id,flexible_degree);
edges[offset] = neighbor.size();
for(int i = 0;i < neighbor.size();++i)
edges[offset + i + 1] = neighbor[i];
return;
//We assert edges[offset] > 0 here
if(curr_dist >= edge_dist[offset + edges[offset]]){
return;
}
edges[offset + edges[offset]] = u_id;
edge_dist[offset + edges[offset]] = curr_dist;
for(size_t i = offset + edges[offset] - 1;i > offset;--i){
if(edge_dist[i] > edge_dist[i + 1]){
std::swap(edges[i],edges[i + 1]);
std::swap(edge_dist[i],edge_dist[i + 1]);
}else{
break;
}
}
}
void rank_and_switch(idx_t v_id,idx_t u_id){
rank_and_switch_ordered(v_id,u_id);
//TODO:
//Implement an unordered version to compare with
}
template<class T>
dist_t distance(idx_t a,T& b){
if(dist_type == 0)
return data->l2_distance(a,b);
else if(dist_type == 1)
return data->negative_inner_prod_distance(a,b);
else if(dist_type == 2)
return data->negative_cosine_distance(a,b);
else if(dist_type == 3)
return data->l2_distance(a,b);
else if(dist_type == 4)
return data->ipwrap_l2_build_distance(a,b);
else if(dist_type == 5)
return data->ipwrap_l2_query_distance(a,b);
else{
// should not happen
fprintf(stderr,"unsupported dist_type %d\n",dist_type);
return 0;
}
}
void compute_distance_naive(size_t offset,std::vector<dist_t>& dists){
dists.resize(edges[offset]);
auto degree = edges[offset];
for(int i = 0;i < degree;++i){
dists[i] = distance(offset >> vertex_offset_shift,edges[offset + i + 1]);
}
}
void compute_distance(size_t offset,std::vector<dist_t>& dists){
compute_distance_naive(offset,dists);
}
template<class T>
dist_t pair_distance_naive(idx_t a,T& b){
++measures.distance_cnt;
return distance(a,b);
}
template<class T>
dist_t pair_distance(idx_t a,T& b){
return pair_distance_naive(a,b);
}
void qsort(size_t l,size_t r){
auto mid = (l + r) >> 1;
int i = l,j = r;
auto k = edge_dist[mid];
do{
while(edge_dist[i] < k) ++i;
while(k < edge_dist[j]) --j;
if(i <= j){
std::swap(edge_dist[i],edge_dist[j]);
std::swap(edges[i],edges[j]);
++i;
--j;
}
}while(i <= j);
if(i < r)qsort(i,r);
if(l < j)qsort(l,j);
}
void rank_edges(size_t offset){
std::vector<dist_t> dists;
compute_distance(offset,dists);
for(int i = 0;i < dists.size();++i)
edge_dist[offset + i + 1] = dists[i];
qsort(offset + 1,offset + dists.size());
//TODO:
//use a heap in the edge_dist
}
void add_edge_lock(idx_t v_id,idx_t u_id){
edge_mutex[v_id].lock();
auto offset = ((size_t)v_id) << vertex_offset_shift;
if(edges[offset] < flexible_degree){
++edges[offset];
edges[offset + edges[offset]] = u_id;
}else{
rank_and_switch(v_id,u_id);
}
edge_mutex[v_id].unlock();
}
void add_edge(idx_t v_id,idx_t u_id){
auto offset = ((size_t)v_id) << vertex_offset_shift;
if(edges[offset] < flexible_degree){
++edges[offset];
edges[offset + edges[offset]] = u_id;
}else{
rank_and_switch(v_id,u_id);
}
}
public:
long long total_explore_cnt = 0;
int total_explore_times = 0;
size_t search_start_point = 0;
bool ignore_startpoint = false;
FixedDegreeGraph(Data* data) : data(data){
auto num_vertices = data->max_vertices();
edges = std::vector<idx_t>(((size_t)num_vertices) << vertex_offset_shift);
edge_dist = std::vector<dist_t>(((size_t)num_vertices) << vertex_offset_shift);
edge_mutex = std::vector<std::mutex>(num_vertices);
p_visited = new VisitedList(num_vertices + 5);
#ifdef OMP
int n_threads = 1;
#pragma omp parallel
#pragma omp master
{
n_threads = omp_get_num_threads();
}
visited_pool.resize(n_threads);
for(int i = 0;i < n_threads;++i)
visited_pool[i] = new VisitedList(num_vertices + 5);
#endif
}
void set_construct_pq_size(int size){
construct_pq_size = size;
}
std::vector<idx_t> edge_selection_filter_neighbor(std::vector<idx_t>& neighbor,idx_t vertex_id,int desired_size){
std::vector<idx_t> filtered_neighbor;
std::vector<dist_t> dists(neighbor.size());
for(int i = 0;i < dists.size();++i)
dists[i] = pair_distance(vertex_id,neighbor[i]);
std::vector<int> idx(neighbor.size());
for(int i = 0;i < idx.size();++i)
idx[i] = i;
std::sort(idx.begin(),idx.end(),[&](int a,int b){return dists[a] < dists[b];});
for(int i = 0;i < idx.size();++i){
dist_t cur_dist = dists[idx[i]];
bool pass = true;
for(auto neighbor_id : filtered_neighbor){
if(cur_dist > pair_distance(neighbor_id,neighbor[idx[i]])){
pass = false;
break;
}
}
if(pass){
filtered_neighbor.push_back(neighbor[idx[i]]);
if(filtered_neighbor.size() >= desired_size)
break;
}else{
}
}
return std::move(filtered_neighbor);
}
void add_vertex_lock(idx_t vertex_id,std::vector<std::pair<int,value_t>>& point){
std::vector<idx_t> neighbor;
search_top_k_lock(point,construct_pq_size,neighbor);
auto offset = ((size_t)vertex_id) << vertex_offset_shift;
int num_neighbors = degree < neighbor.size() ? degree : neighbor.size();
edge_mutex[vertex_id].lock();
// TODO:
// it is possible to save this space --- edges[offset]
// by set the last number in the range as
// a large number - current degree
if(neighbor.size() >= degree)
neighbor = edge_selection_filter_neighbor(neighbor,vertex_id,degree);
edges[offset] = neighbor.size();
for(int i = 0;i < neighbor.size() && i < degree;++i){
edges[offset + i + 1] = neighbor[i];
}
edge_mutex[vertex_id].unlock();
for(int i = 0;i < neighbor.size() && i < degree;++i){
add_edge_lock(neighbor[i],vertex_id);
}
}
void add_vertex(idx_t vertex_id,std::vector<std::pair<int,value_t>>& point){
std::vector<idx_t> neighbor;
search_top_k(point,construct_pq_size,neighbor);
auto offset = ((size_t)vertex_id) << vertex_offset_shift;
int num_neighbors = degree < neighbor.size() ? degree : neighbor.size();
// TODO:
// it is possible to save this space --- edges[offset]
// by set the last number in the range as
// a large number - current degree
if(neighbor.size() >= degree){
neighbor = edge_selection_filter_neighbor(neighbor,vertex_id,degree);
}
edges[offset] = neighbor.size();
for(int i = 0;i < neighbor.size() && i < degree;++i){
edges[offset + i + 1] = neighbor[i];
}
for(int i = 0;i < neighbor.size() && i < degree;++i){
add_edge(neighbor[i],vertex_id);
}
}
void astar_multi_start_search_lock(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result){
std::priority_queue<std::pair<dist_t,idx_t>,std::vector<std::pair<dist_t,idx_t>>,std::greater<std::pair<dist_t,idx_t>>> q;
const int num_start_point = 1;
auto converted_query = dist_type == 3 ? data->organize_point_mobius(query) : data->organize_point(query);
#ifdef OMP
int tid = omp_get_thread_num();
auto& p_visited = visited_pool[tid];
#endif
p_visited->reset();
auto tag = p_visited->curV;
for(int i = 0;i < num_start_point && i < data->curr_vertices();++i){
auto start = search_start_point;//rand_gen() % data->curr_vertices();
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
q.push(std::make_pair(pair_distance_naive(start,converted_query),start));
}
std::priority_queue<std::pair<dist_t,idx_t>> topk;
const int max_step = 1000000;
bool found_min_node = false;
dist_t min_dist = 1e100;
int explore_cnt = 0;
for(int iter = 0;iter < max_step && !q.empty();++iter){
auto now = q.top();
if(topk.size() == k && topk.top().first < now.first){
break;
}
++explore_cnt;
min_dist = std::min(min_dist,now.first);
q.pop();
if(ignore_startpoint == false || iter != 0)
topk.push(now);
if(topk.size() > k)
topk.pop();
edge_mutex[now.second].lock();
auto offset = ((size_t)now.second) << vertex_offset_shift;
auto degree = edges[offset];
for(int i = 0;i < degree;++i){
auto start = edges[offset + i + 1];
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
auto dist = pair_distance_naive(start,converted_query);
if(topk.empty() || dist < topk.top().first || topk.size() < k)
q.push(std::make_pair(dist,start));
}
edge_mutex[now.second].unlock();
}
total_explore_cnt += explore_cnt;
++total_explore_times;
result.resize(topk.size());
int i = result.size() - 1;
while(!topk.empty()){
result[i] = (topk.top().second);
topk.pop();
--i;
}
}
void astar_no_heap_search(const std::vector<std::pair<int,value_t>>& query,std::vector<idx_t>& result){
const int num_start_point = 1;
std::pair<dist_t,idx_t> q_top = std::make_pair(10000000000,0);
auto converted_query = dist_type == 3 ? data->organize_point_mobius(query) : data->organize_point(query);
p_visited->reset();
auto tag = p_visited->curV;
for(int i = 0;i < num_start_point && i < data->curr_vertices();++i){
auto start = search_start_point;//rand_gen() % data->curr_vertices();
p_visited->mass[start] = tag;
if(ignore_startpoint == false){
q_top = (std::make_pair(pair_distance_naive(start,converted_query),start));
}else{
auto offset = ((size_t)start) << vertex_offset_shift;
auto degree = edges[offset];
for(int i = 1;i <= degree;++i){
p_visited->mass[edges[offset + i]] = tag;
auto dis = pair_distance_naive(edges[offset + i],converted_query);
if(dis < q_top.first)
q_top = (std::make_pair(dis,start));
}
}
}
const int max_step = 1000000;
bool found_min_node = false;
dist_t min_dist = 1e100;
int explore_cnt = 0;
for(int iter = 0;iter < max_step;++iter){
++explore_cnt;
auto offset = ((size_t)q_top.second) << vertex_offset_shift;
auto degree = edges[offset];
bool changed = false;
for(int i = 0;i < degree;++i){
auto start = edges[offset + i + 1];
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
auto dist = pair_distance_naive(start,converted_query);
if(dist < q_top.first){
q_top = (std::make_pair(dist,start));
changed = true;
}
}
if(changed == false)
break;
}
total_explore_cnt += explore_cnt;
++total_explore_times;
result.resize(1);
result[0] = q_top.second;
}
void astar_multi_start_search_with_score(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result,std::vector<double>& score){
std::priority_queue<std::pair<dist_t,idx_t>,std::vector<std::pair<dist_t,idx_t>>,std::greater<std::pair<dist_t,idx_t>>> q;
const int num_start_point = 1;
auto converted_query = dist_type == 3 ? data->organize_point_mobius(query) : data->organize_point(query);
p_visited->reset();
auto tag = p_visited->curV;
for(int i = 0;i < num_start_point && i < data->curr_vertices();++i){
auto start = search_start_point;//rand_gen() % data->curr_vertices();
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
q.push(std::make_pair(pair_distance_naive(start,converted_query),start));
}
std::priority_queue<std::pair<dist_t,idx_t>> topk;
const int max_step = 1000000;
bool found_min_node = false;
dist_t min_dist = 1e100;
int explore_cnt = 0;
for(int iter = 0;iter < max_step && !q.empty();++iter){
auto now = q.top();
if(topk.size() == k && topk.top().first < now.first){
break;
}
++explore_cnt;
min_dist = std::min(min_dist,now.first);
q.pop();
if(ignore_startpoint == false || iter != 0)
topk.push(now);
if(topk.size() > k)
topk.pop();
auto offset = ((size_t)now.second) << vertex_offset_shift;
auto degree = edges[offset];
for(int i = 0;i < degree;++i){
auto start = edges[offset + i + 1];
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
auto dist = pair_distance_naive(start,converted_query);
if(topk.empty() || dist < topk.top().first || topk.size() < k)
q.push(std::make_pair(dist,start));
}
}
total_explore_cnt += explore_cnt;
++total_explore_times;
result.resize(topk.size());
score.resize(topk.size());
int i = result.size() - 1;
while(!topk.empty()){
result[i] = (topk.top().second);
score[i] = -(topk.top().first);
topk.pop();
--i;
}
}
void astar_multi_start_search(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result){
std::priority_queue<std::pair<dist_t,idx_t>,std::vector<std::pair<dist_t,idx_t>>,std::greater<std::pair<dist_t,idx_t>>> q;
const int num_start_point = 1;
auto converted_query = dist_type == 3 ? data->organize_point_mobius(query) : data->organize_point(query);
p_visited->reset();
auto tag = p_visited->curV;
for(int i = 0;i < num_start_point && i < data->curr_vertices();++i){
auto start = search_start_point;//rand_gen() % data->curr_vertices();
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
q.push(std::make_pair(pair_distance_naive(start,converted_query),start));
}
std::priority_queue<std::pair<dist_t,idx_t>> topk;
const int max_step = 1000000;
bool found_min_node = false;
dist_t min_dist = 1e100;
int explore_cnt = 0;
for(int iter = 0;iter < max_step && !q.empty();++iter){
auto now = q.top();
if(topk.size() == k && topk.top().first < now.first){
break;
}
++explore_cnt;
min_dist = std::min(min_dist,now.first);
q.pop();
if(ignore_startpoint == false || iter != 0)
topk.push(now);
if(topk.size() > k)
topk.pop();
auto offset = ((size_t)now.second) << vertex_offset_shift;
auto degree = edges[offset];
for(int i = 0;i < degree;++i){
auto start = edges[offset + i + 1];
if(p_visited->mass[start] == tag)
continue;
p_visited->mass[start] = tag;
auto dist = pair_distance_naive(start,converted_query);
if(topk.empty() || dist < topk.top().first || topk.size() < k)
q.push(std::make_pair(dist,start));
}
}
total_explore_cnt += explore_cnt;
++total_explore_times;
result.resize(topk.size());
int i = result.size() - 1;
while(!topk.empty()){
result[i] = (topk.top().second);
topk.pop();
--i;
}
}
void search_top_k(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result){
if(k == 1)
astar_no_heap_search(query,result);
else
astar_multi_start_search(query,k,result);
}
void search_top_k_with_score(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result,std::vector<double>& score){
astar_multi_start_search_with_score(query,k,result,score);
}
void search_top_k_lock(const std::vector<std::pair<int,value_t>>& query,int k,std::vector<idx_t>& result){
astar_multi_start_search_lock(query,k,result);
}
void print_stat(){
auto n = data->max_vertices();
size_t sum = 0;
std::vector<size_t> histogram(2 * degree + 1,0);
for(size_t i = 0;i < n;++i){
sum += edges[i << vertex_offset_shift];
int tmp = edges[i << vertex_offset_shift];
if(tmp > 2 * degree + 1)
fprintf(stderr,"[ERROR] node %zu has %d degree\n",i,tmp);
++histogram[edges[i << vertex_offset_shift]];
if(tmp != degree)
fprintf(stderr,"[INFO] %zu has degree %d\n",i,tmp);
}
fprintf(stderr,"[INFO] #vertices %zu, avg degree %f\n",n,sum * 1.0 / n);
std::unordered_set<idx_t> visited;
fprintf(stderr,"[INFO] degree histogram:\n");
for(int i = 0;i <= 2 * degree + 1;++i)
fprintf(stderr,"[INFO] %d:\t%zu\n",i,histogram[i]);
}
void print_edges(int x){
for(size_t i = 0;i < x;++i){
size_t offset = i << vertex_offset_shift;
int degree = edges[offset];
fprintf(stderr,"%d (%d): ",i,degree);
for(int j = 1;j <= degree;++j)
fprintf(stderr,"(%zu,%f) ",edges[offset + j],edge_dist[offset + j]);
fprintf(stderr,"\n");
}
}
void dump(std::string path = "bfsg.graph"){
FILE* fp = fopen(path.c_str(),"wb");
size_t num_vertices = data->max_vertices();
fwrite(&edges[0],sizeof(edges[0]) * (num_vertices << vertex_offset_shift),1,fp);
fclose(fp);
}
void load(std::string path = "bfsg.graph"){
FILE* fp = fopen(path.c_str(),"rb");
size_t num_vertices = data->max_vertices();
auto cnt = fread(&edges[0],sizeof(edges[0]) * (num_vertices << vertex_offset_shift),1,fp);
fclose(fp);
}
Data* get_data(){
return data;
}
};
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from interface import Graph_Index
# 随机产生样本
index_vectors = np.random.rand(100000,128).astype(np.float32)
query_vector = np.random.rand(128).astype(np.float32)
index_docs = ["ID_"+str(i) for i in range(100000)]
# 初始化索引结构
indexer = Graph_Index(dist_type="IP") #支持"IP"和"L2"
indexer.build(gallery_vectors=index_vectors, gallery_docs=index_docs, pq_size=100, index_path='test')
# 查询
scores, docs = indexer.search(query=query_vector, return_k=10, search_budget=100)
print(scores)
print(docs)
# 保存与加载
indexer.dump(index_path="test")
indexer.load(index_path="test")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册