Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
829cb4bb
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
829cb4bb
编写于
2月 08, 2020
作者:
T
Tinkerrr
提交者:
GitHub
2月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support hnsw (#1131)
* add hnsw * add config * format... * format..
上级
4bd3b62b
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
2310 addition
and
13 deletion
+2310
-13
core/src/db/engine/ExecutionEngine.h
core/src/db/engine/ExecutionEngine.h
+4
-3
core/src/db/engine/ExecutionEngineImpl.cpp
core/src/db/engine/ExecutionEngineImpl.cpp
+9
-4
core/src/index/knowhere/CMakeLists.txt
core/src/index/knowhere/CMakeLists.txt
+1
-0
core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp
.../index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp
+170
-0
core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h
...rc/index/knowhere/knowhere/index/vector_index/IndexHNSW.h
+69
-0
core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
...ex/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
+12
-0
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
...here/knowhere/index/vector_index/helpers/IndexParameter.h
+12
-0
core/src/index/thirdparty/hnswlib/bruteforce.h
core/src/index/thirdparty/hnswlib/bruteforce.h
+170
-0
core/src/index/thirdparty/hnswlib/hnswalg.h
core/src/index/thirdparty/hnswlib/hnswalg.h
+1160
-0
core/src/index/thirdparty/hnswlib/hnswlib.h
core/src/index/thirdparty/hnswlib/hnswlib.h
+98
-0
core/src/index/thirdparty/hnswlib/space_ip.h
core/src/index/thirdparty/hnswlib/space_ip.h
+248
-0
core/src/index/thirdparty/hnswlib/space_l2.h
core/src/index/thirdparty/hnswlib/space_l2.h
+244
-0
core/src/index/thirdparty/hnswlib/visited_list_pool.h
core/src/index/thirdparty/hnswlib/visited_list_pool.h
+78
-0
core/src/wrapper/ConfAdapter.cpp
core/src/wrapper/ConfAdapter.cpp
+16
-4
core/src/wrapper/ConfAdapter.h
core/src/wrapper/ConfAdapter.h
+8
-2
core/src/wrapper/ConfAdapterMgr.cpp
core/src/wrapper/ConfAdapterMgr.cpp
+3
-0
core/src/wrapper/VecIndex.cpp
core/src/wrapper/VecIndex.cpp
+7
-0
core/src/wrapper/VecIndex.h
core/src/wrapper/VecIndex.h
+1
-0
未找到文件。
core/src/db/engine/ExecutionEngine.h
浏览文件 @
829cb4bb
...
...
@@ -17,12 +17,12 @@
#pragma once
#include "utils/Status.h"
#include <memory>
#include <string>
#include <vector>
#include "utils/Status.h"
namespace
milvus
{
namespace
engine
{
...
...
@@ -39,7 +39,8 @@ enum class EngineType {
SPTAG_BKT
,
FAISS_BIN_IDMAP
,
FAISS_BIN_IVFFLAT
,
MAX_VALUE
=
FAISS_BIN_IVFFLAT
,
HNSW
,
MAX_VALUE
=
HNSW
,
};
enum
class
MetricType
{
...
...
core/src/db/engine/ExecutionEngineImpl.cpp
浏览文件 @
829cb4bb
...
...
@@ -16,6 +16,11 @@
// under the License.
#include "db/engine/ExecutionEngineImpl.h"
#include <stdexcept>
#include <utility>
#include <vector>
#include "cache/CpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "knowhere/common/Config.h"
...
...
@@ -33,10 +38,6 @@
#include "wrapper/VecImpl.h"
#include "wrapper/VecIndex.h"
#include <stdexcept>
#include <utility>
#include <vector>
//#define ON_SEARCH
namespace
milvus
{
namespace
engine
{
...
...
@@ -196,6 +197,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
index
=
GetVecIndexFactory
(
IndexType
::
SPTAG_BKT_RNT_CPU
);
break
;
}
case
EngineType
::
HNSW
:
{
index
=
GetVecIndexFactory
(
IndexType
::
HNSW
);
break
;
}
case
EngineType
::
FAISS_BIN_IDMAP
:
{
index
=
GetVecIndexFactory
(
IndexType
::
FAISS_BIN_IDMAP
);
break
;
...
...
core/src/index/knowhere/CMakeLists.txt
浏览文件 @
829cb4bb
...
...
@@ -37,6 +37,7 @@ set(index_srcs
knowhere/index/vector_index/IndexBinaryIDMAP.cpp
knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/IndexHNSW.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
knowhere/index/vector_index/nsg/NSGHelper.cpp
...
...
core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp
0 → 100644
浏览文件 @
829cb4bb
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <vector>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexHNSW.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
#include "hnswlib/hnswalg.h"
#include "hnswlib/space_ip.h"
#include "hnswlib/space_l2.h"
namespace
knowhere
{
BinarySet
IndexHNSW
::
Serialize
()
{
if
(
!
index_
)
{
KNOWHERE_THROW_MSG
(
"index not initialize or trained"
);
}
try
{
MemoryIOWriter
writer
;
index_
->
saveIndex
(
writer
);
auto
data
=
std
::
make_shared
<
uint8_t
>
();
data
.
reset
(
writer
.
data_
);
BinarySet
res_set
;
res_set
.
Append
(
"HNSW"
,
data
,
writer
.
total
);
return
res_set
;
}
catch
(
std
::
exception
&
e
)
{
KNOWHERE_THROW_MSG
(
e
.
what
());
}
}
void
IndexHNSW
::
Load
(
const
BinarySet
&
index_binary
)
{
try
{
auto
binary
=
index_binary
.
GetByName
(
"HNSW"
);
MemoryIOReader
reader
;
reader
.
total
=
binary
->
size
;
reader
.
data_
=
binary
->
data
.
get
();
hnswlib
::
SpaceInterface
<
float
>*
space
;
index_
=
std
::
make_shared
<
hnswlib
::
HierarchicalNSW
<
float
>>
(
space
);
index_
->
loadIndex
(
reader
);
}
catch
(
std
::
exception
&
e
)
{
KNOWHERE_THROW_MSG
(
e
.
what
());
}
}
DatasetPtr
IndexHNSW
::
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
auto
search_cfg
=
std
::
dynamic_pointer_cast
<
HNSWCfg
>
(
config
);
if
(
search_cfg
!=
nullptr
)
{
search_cfg
->
CheckValid
();
// throw exception
}
if
(
!
index_
)
{
KNOWHERE_THROW_MSG
(
"index not initialize or trained"
);
}
GETTENSOR
(
dataset
)
using
P
=
std
::
pair
<
float
,
int64_t
>
;
auto
compare
=
[](
P
v1
,
P
v2
)
{
return
v1
.
second
<
v2
.
second
;
};
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
ret
=
index_
->
searchKnn
(
p_data
,
search_cfg
->
k
,
compare
);
std
::
vector
<
float
>
dist
(
ret
.
size
());
std
::
vector
<
int64_t
>
ids
(
ret
.
size
());
std
::
transform
(
ret
.
begin
(),
ret
.
end
(),
std
::
back_inserter
(
dist
),
[](
const
std
::
pair
<
float
,
int64_t
>&
e
)
{
return
e
.
first
;
});
std
::
transform
(
ret
.
begin
(),
ret
.
end
(),
std
::
back_inserter
(
ids
),
[](
const
std
::
pair
<
float
,
int64_t
>&
e
)
{
return
e
.
second
;
});
auto
elems
=
rows
*
search_cfg
->
k
;
assert
(
elems
==
ret
.
size
());
size_t
p_id_size
=
sizeof
(
int64_t
)
*
elems
;
size_t
p_dist_size
=
sizeof
(
float
)
*
elems
;
auto
p_id
=
(
int64_t
*
)
malloc
(
p_id_size
);
auto
p_dist
=
(
float
*
)
malloc
(
p_dist_size
);
memcpy
(
p_dist
,
dist
.
data
(),
dist
.
size
()
*
sizeof
(
float
));
memcpy
(
p_id
,
ids
.
data
(),
ids
.
size
()
*
sizeof
(
int64_t
));
auto
ret_ds
=
std
::
make_shared
<
Dataset
>
();
ret_ds
->
Set
(
meta
::
IDS
,
p_id
);
ret_ds
->
Set
(
meta
::
DISTANCE
,
p_dist
);
}
IndexModelPtr
IndexHNSW
::
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
auto
build_cfg
=
std
::
dynamic_pointer_cast
<
HNSWCfg
>
(
config
);
if
(
build_cfg
!=
nullptr
)
{
build_cfg
->
CheckValid
();
// throw exception
}
GETTENSOR
(
dataset
)
hnswlib
::
SpaceInterface
<
float
>*
space
;
if
(
config
->
metric_type
==
METRICTYPE
::
L2
)
{
space
=
new
hnswlib
::
L2Space
(
dim
);
}
else
if
(
config
->
metric_type
==
METRICTYPE
::
IP
)
{
space
=
new
hnswlib
::
InnerProductSpace
(
dim
);
}
index_
=
std
::
make_shared
<
hnswlib
::
HierarchicalNSW
<
float
>>
(
space
,
rows
,
build_cfg
->
M
,
build_cfg
->
ef
);
return
nullptr
;
}
void
IndexHNSW
::
Add
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
if
(
!
index_
)
{
KNOWHERE_THROW_MSG
(
"index not initialize"
);
}
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
GETTENSOR
(
dataset
)
auto
p_ids
=
dataset
->
Get
<
const
int64_t
*>
(
meta
::
IDS
);
for
(
int
i
=
0
;
i
<
1
;
i
++
)
{
index_
->
addPoint
((
void
*
)(
p_data
+
dim
*
i
),
p_ids
[
i
]);
}
#pragma omp parallel for
for
(
int
i
=
1
;
i
<
rows
;
i
++
)
{
index_
->
addPoint
((
void
*
)(
p_data
+
dim
*
i
),
p_ids
[
i
]);
}
}
void
IndexHNSW
::
Seal
()
{
// do nothing
}
int64_t
IndexHNSW
::
Count
()
{
if
(
!
index_
)
{
KNOWHERE_THROW_MSG
(
"index not initialize"
);
}
return
index_
->
cur_element_count
;
}
int64_t
IndexHNSW
::
Dimension
()
{
if
(
!
index_
)
{
KNOWHERE_THROW_MSG
(
"index not initialize"
);
}
return
(
*
(
size_t
*
)
index_
->
dist_func_param_
);
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h
0 → 100644
浏览文件 @
829cb4bb
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <mutex>
#include "hnswlib/hnswlib.h"
#include "knowhere/index/vector_index/VectorIndex.h"
namespace
knowhere
{
class
IndexHNSW
:
public
VectorIndex
{
public:
BinarySet
Serialize
()
override
;
void
Load
(
const
BinarySet
&
index_binary
)
override
;
DatasetPtr
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
// void
// set_preprocessor(PreprocessorPtr preprocessor) override;
//
// void
// set_index_model(IndexModelPtr model) override;
//
// PreprocessorPtr
// BuildPreprocessor(const DatasetPtr& dataset, const Config& config) override;
IndexModelPtr
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
void
Add
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
void
Seal
()
override
;
int64_t
Count
()
override
;
int64_t
Dimension
()
override
;
private:
std
::
mutex
mutex_
;
std
::
shared_ptr
<
hnswlib
::
HierarchicalNSW
<
float
>>
index_
;
};
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h
浏览文件 @
829cb4bb
...
...
@@ -28,6 +28,12 @@ struct MemoryIOWriter : public faiss::IOWriter {
size_t
operator
()(
const
void
*
ptr
,
size_t
size
,
size_t
nitems
)
override
;
template
<
typename
T
>
size_t
write
(
T
*
ptr
,
size_t
size
,
size_t
nitems
=
1
)
{
operator
()((
const
void
*
)
ptr
,
size
,
nitems
);
}
};
struct
MemoryIOReader
:
public
faiss
::
IOReader
{
...
...
@@ -37,6 +43,12 @@ struct MemoryIOReader : public faiss::IOReader {
size_t
operator
()(
void
*
ptr
,
size_t
size
,
size_t
nitems
)
override
;
template
<
typename
T
>
size_t
read
(
T
*
ptr
,
size_t
size
,
size_t
nitems
=
1
)
{
operator
()((
void
*
)
ptr
,
size
,
nitems
);
}
};
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
浏览文件 @
829cb4bb
...
...
@@ -68,6 +68,10 @@ constexpr int64_t DEFAULT_BKTNUMBER = INVALID_VALUE;
constexpr
int64_t
DEFAULT_BKTKMEANSK
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_BKTLEAFSIZE
=
INVALID_VALUE
;
// HNSW Config
constexpr
int64_t
DEFAULT_M
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_EF
=
INVALID_VALUE
;
struct
IVFCfg
:
public
Cfg
{
int64_t
nlist
=
DEFAULT_NLIST
;
int64_t
nprobe
=
DEFAULT_NPROBE
;
...
...
@@ -242,4 +246,12 @@ struct BinIDMAPCfg : public Cfg {
}
};
struct
HNSWCfg
:
public
Cfg
{
int64_t
M
=
DEFAULT_M
;
int64_t
ef
=
DEFAULT_EF
;
HNSWCfg
()
=
default
;
};
using
HNSWConfig
=
std
::
shared_ptr
<
HNSWCfg
>
;
}
// namespace knowhere
core/src/index/thirdparty/hnswlib/bruteforce.h
0 → 100644
浏览文件 @
829cb4bb
#pragma once
#include <unordered_map>
#include <fstream>
#include <mutex>
#include <algorithm>
namespace
hnswlib
{
template
<
typename
dist_t
>
class
BruteforceSearch
:
public
AlgorithmInterface
<
dist_t
>
{
public:
BruteforceSearch
(
SpaceInterface
<
dist_t
>
*
s
)
{
}
BruteforceSearch
(
SpaceInterface
<
dist_t
>
*
s
,
const
std
::
string
&
location
)
{
loadIndex
(
location
,
s
);
}
BruteforceSearch
(
SpaceInterface
<
dist_t
>
*
s
,
size_t
maxElements
)
{
maxelements_
=
maxElements
;
data_size_
=
s
->
get_data_size
();
fstdistfunc_
=
s
->
get_dist_func
();
dist_func_param_
=
s
->
get_dist_func_param
();
size_per_element_
=
data_size_
+
sizeof
(
labeltype
);
data_
=
(
char
*
)
malloc
(
maxElements
*
size_per_element_
);
if
(
data_
==
nullptr
)
std
::
runtime_error
(
"Not enough memory: BruteforceSearch failed to allocate data"
);
cur_element_count
=
0
;
}
~
BruteforceSearch
()
{
free
(
data_
);
}
char
*
data_
;
size_t
maxelements_
;
size_t
cur_element_count
;
size_t
size_per_element_
;
size_t
data_size_
;
DISTFUNC
<
dist_t
>
fstdistfunc_
;
void
*
dist_func_param_
;
std
::
mutex
index_lock
;
std
::
unordered_map
<
labeltype
,
size_t
>
dict_external_to_internal
;
void
addPoint
(
const
void
*
datapoint
,
labeltype
label
)
{
int
idx
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
index_lock
);
auto
search
=
dict_external_to_internal
.
find
(
label
);
if
(
search
!=
dict_external_to_internal
.
end
())
{
idx
=
search
->
second
;
}
else
{
if
(
cur_element_count
>=
maxelements_
)
{
throw
std
::
runtime_error
(
"The number of elements exceeds the specified limit
\n
"
);
}
idx
=
cur_element_count
;
dict_external_to_internal
[
label
]
=
idx
;
cur_element_count
++
;
}
}
memcpy
(
data_
+
size_per_element_
*
idx
+
data_size_
,
&
label
,
sizeof
(
labeltype
));
memcpy
(
data_
+
size_per_element_
*
idx
,
datapoint
,
data_size_
);
};
void
removePoint
(
labeltype
cur_external
)
{
size_t
cur_c
=
dict_external_to_internal
[
cur_external
];
dict_external_to_internal
.
erase
(
cur_external
);
labeltype
label
=*
((
labeltype
*
)(
data_
+
size_per_element_
*
(
cur_element_count
-
1
)
+
data_size_
));
dict_external_to_internal
[
label
]
=
cur_c
;
memcpy
(
data_
+
size_per_element_
*
cur_c
,
data_
+
size_per_element_
*
(
cur_element_count
-
1
),
data_size_
+
sizeof
(
labeltype
));
cur_element_count
--
;
}
std
::
priority_queue
<
std
::
pair
<
dist_t
,
labeltype
>>
searchKnn
(
const
void
*
query_data
,
size_t
k
)
const
{
std
::
priority_queue
<
std
::
pair
<
dist_t
,
labeltype
>>
topResults
;
if
(
cur_element_count
==
0
)
return
topResults
;
for
(
int
i
=
0
;
i
<
k
;
i
++
)
{
dist_t
dist
=
fstdistfunc_
(
query_data
,
data_
+
size_per_element_
*
i
,
dist_func_param_
);
topResults
.
push
(
std
::
pair
<
dist_t
,
labeltype
>
(
dist
,
*
((
labeltype
*
)
(
data_
+
size_per_element_
*
i
+
data_size_
))));
}
dist_t
lastdist
=
topResults
.
top
().
first
;
for
(
int
i
=
k
;
i
<
cur_element_count
;
i
++
)
{
dist_t
dist
=
fstdistfunc_
(
query_data
,
data_
+
size_per_element_
*
i
,
dist_func_param_
);
if
(
dist
<=
lastdist
)
{
topResults
.
push
(
std
::
pair
<
dist_t
,
labeltype
>
(
dist
,
*
((
labeltype
*
)
(
data_
+
size_per_element_
*
i
+
data_size_
))));
if
(
topResults
.
size
()
>
k
)
topResults
.
pop
();
lastdist
=
topResults
.
top
().
first
;
}
}
return
topResults
;
};
template
<
typename
Comp
>
std
::
vector
<
std
::
pair
<
dist_t
,
labeltype
>>
searchKnn
(
const
void
*
query_data
,
size_t
k
,
Comp
comp
)
{
std
::
vector
<
std
::
pair
<
dist_t
,
labeltype
>>
result
;
if
(
cur_element_count
==
0
)
return
result
;
auto
ret
=
searchKnn
(
query_data
,
k
);
while
(
!
ret
.
empty
())
{
result
.
push_back
(
ret
.
top
());
ret
.
pop
();
}
std
::
sort
(
result
.
begin
(),
result
.
end
(),
comp
);
return
result
;
}
void
saveIndex
(
const
std
::
string
&
location
)
{
std
::
ofstream
output
(
location
,
std
::
ios
::
binary
);
std
::
streampos
position
;
writeBinaryPOD
(
output
,
maxelements_
);
writeBinaryPOD
(
output
,
size_per_element_
);
writeBinaryPOD
(
output
,
cur_element_count
);
output
.
write
(
data_
,
maxelements_
*
size_per_element_
);
output
.
close
();
}
void
loadIndex
(
const
std
::
string
&
location
,
SpaceInterface
<
dist_t
>
*
s
)
{
std
::
ifstream
input
(
location
,
std
::
ios
::
binary
);
std
::
streampos
position
;
readBinaryPOD
(
input
,
maxelements_
);
readBinaryPOD
(
input
,
size_per_element_
);
readBinaryPOD
(
input
,
cur_element_count
);
data_size_
=
s
->
get_data_size
();
fstdistfunc_
=
s
->
get_dist_func
();
dist_func_param_
=
s
->
get_dist_func_param
();
size_per_element_
=
data_size_
+
sizeof
(
labeltype
);
data_
=
(
char
*
)
malloc
(
maxelements_
*
size_per_element_
);
if
(
data_
==
nullptr
)
std
::
runtime_error
(
"Not enough memory: loadIndex failed to allocate data"
);
input
.
read
(
data_
,
maxelements_
*
size_per_element_
);
input
.
close
();
}
};
}
core/src/index/thirdparty/hnswlib/hnswalg.h
0 → 100644
浏览文件 @
829cb4bb
此差异已折叠。
点击以展开。
core/src/index/thirdparty/hnswlib/hnswlib.h
0 → 100644
浏览文件 @
829cb4bb
#pragma once
#ifndef NO_MANUAL_VECTORIZATION
#ifdef __SSE__
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#endif
#endif
#endif
#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
#else
#include <x86intrin.h>
#endif
#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#endif
#endif
#include <queue>
#include <vector>
#include <string.h>
namespace
hnswlib
{
typedef
int64_t
labeltype
;
template
<
typename
T
>
class
pairGreater
{
public:
bool
operator
()(
const
T
&
p1
,
const
T
&
p2
)
{
return
p1
.
first
>
p2
.
first
;
}
};
template
<
typename
T
>
static
void
writeBinaryPOD
(
std
::
ostream
&
out
,
const
T
&
podRef
)
{
out
.
write
((
char
*
)
&
podRef
,
sizeof
(
T
));
}
template
<
typename
T
>
static
void
readBinaryPOD
(
std
::
istream
&
in
,
T
&
podRef
)
{
in
.
read
((
char
*
)
&
podRef
,
sizeof
(
T
));
}
template
<
typename
T
,
typename
W
>
static
void
writeBinaryPOD
(
W
&
out
,
const
T
&
podRef
)
{
out
.
write
((
char
*
)
&
podRef
,
sizeof
(
T
));
}
template
<
typename
T
,
typename
R
>
static
void
readBinaryPOD
(
R
&
in
,
T
&
podRef
)
{
in
.
read
((
char
*
)
&
podRef
,
sizeof
(
T
));
}
template
<
typename
MTYPE
>
using
DISTFUNC
=
MTYPE
(
*
)(
const
void
*
,
const
void
*
,
const
void
*
);
template
<
typename
MTYPE
>
class
SpaceInterface
{
public:
//virtual void search(void *);
virtual
size_t
get_data_size
()
=
0
;
virtual
DISTFUNC
<
MTYPE
>
get_dist_func
()
=
0
;
virtual
void
*
get_dist_func_param
()
=
0
;
virtual
~
SpaceInterface
()
{}
};
template
<
typename
dist_t
>
class
AlgorithmInterface
{
public:
virtual
void
addPoint
(
const
void
*
datapoint
,
labeltype
label
)
=
0
;
virtual
std
::
priority_queue
<
std
::
pair
<
dist_t
,
labeltype
>>
searchKnn
(
const
void
*
,
size_t
)
const
=
0
;
template
<
typename
Comp
>
std
::
vector
<
std
::
pair
<
dist_t
,
labeltype
>>
searchKnn
(
const
void
*
,
size_t
,
Comp
)
{
}
virtual
void
saveIndex
(
const
std
::
string
&
location
)
=
0
;
virtual
~
AlgorithmInterface
(){
}
};
}
#include "space_l2.h"
#include "space_ip.h"
#include "bruteforce.h"
#include "hnswalg.h"
core/src/index/thirdparty/hnswlib/space_ip.h
0 → 100644
浏览文件 @
829cb4bb
#pragma once
#include "hnswlib.h"
namespace
hnswlib
{
static
float
InnerProduct
(
const
void
*
pVect1
,
const
void
*
pVect2
,
const
void
*
qty_ptr
)
{
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
float
res
=
0
;
for
(
unsigned
i
=
0
;
i
<
qty
;
i
++
)
{
res
+=
((
float
*
)
pVect1
)[
i
]
*
((
float
*
)
pVect2
)[
i
];
}
return
(
1.0
f
-
res
);
}
#if defined(USE_AVX)
// Favor using AVX if available.
static
float
InnerProductSIMD4Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
PORTABLE_ALIGN32
TmpRes
[
8
];
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
size_t
qty16
=
qty
/
16
;
size_t
qty4
=
qty
/
4
;
const
float
*
pEnd1
=
pVect1
+
16
*
qty16
;
const
float
*
pEnd2
=
pVect1
+
4
*
qty4
;
__m256
sum256
=
_mm256_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
__m256
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
sum256
=
_mm256_add_ps
(
sum256
,
_mm256_mul_ps
(
v1
,
v2
));
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
sum256
=
_mm256_add_ps
(
sum256
,
_mm256_mul_ps
(
v1
,
v2
));
}
__m128
v1
,
v2
;
__m128
sum_prod
=
_mm_add_ps
(
_mm256_extractf128_ps
(
sum256
,
0
),
_mm256_extractf128_ps
(
sum256
,
1
));
while
(
pVect1
<
pEnd2
)
{
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
}
_mm_store_ps
(
TmpRes
,
sum_prod
);
float
sum
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
];;
return
1.0
f
-
sum
;
}
#elif defined(USE_SSE)
static
float
InnerProductSIMD4Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
PORTABLE_ALIGN32
TmpRes
[
8
];
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
size_t
qty16
=
qty
/
16
;
size_t
qty4
=
qty
/
4
;
const
float
*
pEnd1
=
pVect1
+
16
*
qty16
;
const
float
*
pEnd2
=
pVect1
+
4
*
qty4
;
__m128
v1
,
v2
;
__m128
sum_prod
=
_mm_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
}
while
(
pVect1
<
pEnd2
)
{
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
}
_mm_store_ps
(
TmpRes
,
sum_prod
);
float
sum
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
];
return
1.0
f
-
sum
;
}
#endif
#if defined(USE_AVX)
static
float
InnerProductSIMD16Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
PORTABLE_ALIGN32
TmpRes
[
8
];
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
size_t
qty16
=
qty
/
16
;
const
float
*
pEnd1
=
pVect1
+
16
*
qty16
;
__m256
sum256
=
_mm256_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
__m256
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
sum256
=
_mm256_add_ps
(
sum256
,
_mm256_mul_ps
(
v1
,
v2
));
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
sum256
=
_mm256_add_ps
(
sum256
,
_mm256_mul_ps
(
v1
,
v2
));
}
_mm256_store_ps
(
TmpRes
,
sum256
);
float
sum
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
]
+
TmpRes
[
4
]
+
TmpRes
[
5
]
+
TmpRes
[
6
]
+
TmpRes
[
7
];
return
1.0
f
-
sum
;
}
#elif defined(USE_SSE)
static
float
InnerProductSIMD16Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
PORTABLE_ALIGN32
TmpRes
[
8
];
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
size_t
qty16
=
qty
/
16
;
const
float
*
pEnd1
=
pVect1
+
16
*
qty16
;
__m128
v1
,
v2
;
__m128
sum_prod
=
_mm_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
sum_prod
=
_mm_add_ps
(
sum_prod
,
_mm_mul_ps
(
v1
,
v2
));
}
_mm_store_ps
(
TmpRes
,
sum_prod
);
float
sum
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
];
return
1.0
f
-
sum
;
}
#endif
class
InnerProductSpace
:
public
SpaceInterface
<
float
>
{
DISTFUNC
<
float
>
fstdistfunc_
;
size_t
data_size_
;
size_t
dim_
;
public:
InnerProductSpace
(
size_t
dim
)
{
fstdistfunc_
=
InnerProduct
;
#if defined(USE_AVX) || defined(USE_SSE)
if
(
dim
%
4
==
0
)
fstdistfunc_
=
InnerProductSIMD4Ext
;
if
(
dim
%
16
==
0
)
fstdistfunc_
=
InnerProductSIMD16Ext
;
#endif
dim_
=
dim
;
data_size_
=
dim
*
sizeof
(
float
);
}
size_t
get_data_size
()
{
return
data_size_
;
}
DISTFUNC
<
float
>
get_dist_func
()
{
return
fstdistfunc_
;
}
void
*
get_dist_func_param
()
{
return
&
dim_
;
}
~
InnerProductSpace
()
{}
};
}
core/src/index/thirdparty/hnswlib/space_l2.h
0 → 100644
浏览文件 @
829cb4bb
#pragma once
#include "hnswlib.h"
namespace
hnswlib
{
static
float
L2Sqr
(
const
void
*
pVect1
,
const
void
*
pVect2
,
const
void
*
qty_ptr
)
{
//return *((float *)pVect2);
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
float
res
=
0
;
for
(
unsigned
i
=
0
;
i
<
qty
;
i
++
)
{
float
t
=
((
float
*
)
pVect1
)[
i
]
-
((
float
*
)
pVect2
)[
i
];
res
+=
t
*
t
;
}
return
(
res
);
}
#if defined(USE_AVX)
// Favor using AVX if available.
static
float
L2SqrSIMD16Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
float
PORTABLE_ALIGN32
TmpRes
[
8
];
size_t
qty16
=
qty
>>
4
;
const
float
*
pEnd1
=
pVect1
+
(
qty16
<<
4
);
__m256
diff
,
v1
,
v2
;
__m256
sum
=
_mm256_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
diff
=
_mm256_sub_ps
(
v1
,
v2
);
sum
=
_mm256_add_ps
(
sum
,
_mm256_mul_ps
(
diff
,
diff
));
v1
=
_mm256_loadu_ps
(
pVect1
);
pVect1
+=
8
;
v2
=
_mm256_loadu_ps
(
pVect2
);
pVect2
+=
8
;
diff
=
_mm256_sub_ps
(
v1
,
v2
);
sum
=
_mm256_add_ps
(
sum
,
_mm256_mul_ps
(
diff
,
diff
));
}
_mm256_store_ps
(
TmpRes
,
sum
);
float
res
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
]
+
TmpRes
[
4
]
+
TmpRes
[
5
]
+
TmpRes
[
6
]
+
TmpRes
[
7
];
return
(
res
);
}
#elif defined(USE_SSE)
static
float
L2SqrSIMD16Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
float
PORTABLE_ALIGN32
TmpRes
[
8
];
// size_t qty4 = qty >> 2;
size_t
qty16
=
qty
>>
4
;
const
float
*
pEnd1
=
pVect1
+
(
qty16
<<
4
);
// const float* pEnd2 = pVect1 + (qty4 << 2);
// const float* pEnd3 = pVect1 + qty;
__m128
diff
,
v1
,
v2
;
__m128
sum
=
_mm_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
diff
=
_mm_sub_ps
(
v1
,
v2
);
sum
=
_mm_add_ps
(
sum
,
_mm_mul_ps
(
diff
,
diff
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
diff
=
_mm_sub_ps
(
v1
,
v2
);
sum
=
_mm_add_ps
(
sum
,
_mm_mul_ps
(
diff
,
diff
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
diff
=
_mm_sub_ps
(
v1
,
v2
);
sum
=
_mm_add_ps
(
sum
,
_mm_mul_ps
(
diff
,
diff
));
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
diff
=
_mm_sub_ps
(
v1
,
v2
);
sum
=
_mm_add_ps
(
sum
,
_mm_mul_ps
(
diff
,
diff
));
}
_mm_store_ps
(
TmpRes
,
sum
);
float
res
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
];
return
(
res
);
}
#endif
#ifdef USE_SSE
static
float
L2SqrSIMD4Ext
(
const
void
*
pVect1v
,
const
void
*
pVect2v
,
const
void
*
qty_ptr
)
{
float
PORTABLE_ALIGN32
TmpRes
[
8
];
float
*
pVect1
=
(
float
*
)
pVect1v
;
float
*
pVect2
=
(
float
*
)
pVect2v
;
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
// size_t qty4 = qty >> 2;
size_t
qty16
=
qty
>>
2
;
const
float
*
pEnd1
=
pVect1
+
(
qty16
<<
2
);
__m128
diff
,
v1
,
v2
;
__m128
sum
=
_mm_set1_ps
(
0
);
while
(
pVect1
<
pEnd1
)
{
v1
=
_mm_loadu_ps
(
pVect1
);
pVect1
+=
4
;
v2
=
_mm_loadu_ps
(
pVect2
);
pVect2
+=
4
;
diff
=
_mm_sub_ps
(
v1
,
v2
);
sum
=
_mm_add_ps
(
sum
,
_mm_mul_ps
(
diff
,
diff
));
}
_mm_store_ps
(
TmpRes
,
sum
);
float
res
=
TmpRes
[
0
]
+
TmpRes
[
1
]
+
TmpRes
[
2
]
+
TmpRes
[
3
];
return
(
res
);
}
#endif
class
L2Space
:
public
SpaceInterface
<
float
>
{
DISTFUNC
<
float
>
fstdistfunc_
;
size_t
data_size_
;
size_t
dim_
;
public:
L2Space
(
size_t
dim
)
{
fstdistfunc_
=
L2Sqr
;
#if defined(USE_SSE) || defined(USE_AVX)
if
(
dim
%
4
==
0
)
fstdistfunc_
=
L2SqrSIMD4Ext
;
if
(
dim
%
16
==
0
)
fstdistfunc_
=
L2SqrSIMD16Ext
;
/*else{
throw runtime_error("Data type not supported!");
}*/
#endif
dim_
=
dim
;
data_size_
=
dim
*
sizeof
(
float
);
}
size_t
get_data_size
()
{
return
data_size_
;
}
DISTFUNC
<
float
>
get_dist_func
()
{
return
fstdistfunc_
;
}
void
*
get_dist_func_param
()
{
return
&
dim_
;
}
~
L2Space
()
{}
};
static
int
L2SqrI
(
const
void
*
__restrict
pVect1
,
const
void
*
__restrict
pVect2
,
const
void
*
__restrict
qty_ptr
)
{
size_t
qty
=
*
((
size_t
*
)
qty_ptr
);
int
res
=
0
;
unsigned
char
*
a
=
(
unsigned
char
*
)
pVect1
;
unsigned
char
*
b
=
(
unsigned
char
*
)
pVect2
;
/*for (int i = 0; i < qty; i++) {
int t = int((a)[i]) - int((b)[i]);
res += t*t;
}*/
qty
=
qty
>>
2
;
for
(
size_t
i
=
0
;
i
<
qty
;
i
++
)
{
res
+=
((
*
a
)
-
(
*
b
))
*
((
*
a
)
-
(
*
b
));
a
++
;
b
++
;
res
+=
((
*
a
)
-
(
*
b
))
*
((
*
a
)
-
(
*
b
));
a
++
;
b
++
;
res
+=
((
*
a
)
-
(
*
b
))
*
((
*
a
)
-
(
*
b
));
a
++
;
b
++
;
res
+=
((
*
a
)
-
(
*
b
))
*
((
*
a
)
-
(
*
b
));
a
++
;
b
++
;
}
return
(
res
);
}
class
L2SpaceI
:
public
SpaceInterface
<
int
>
{
DISTFUNC
<
int
>
fstdistfunc_
;
size_t
data_size_
;
size_t
dim_
;
public:
L2SpaceI
(
size_t
dim
)
{
fstdistfunc_
=
L2SqrI
;
dim_
=
dim
;
data_size_
=
dim
*
sizeof
(
unsigned
char
);
}
size_t
get_data_size
()
{
return
data_size_
;
}
DISTFUNC
<
int
>
get_dist_func
()
{
return
fstdistfunc_
;
}
void
*
get_dist_func_param
()
{
return
&
dim_
;
}
~
L2SpaceI
()
{}
};
}
core/src/index/thirdparty/hnswlib/visited_list_pool.h
0 → 100644
浏览文件 @
829cb4bb
#pragma once
#include <mutex>
#include <string.h>
namespace
hnswlib
{
typedef
unsigned
short
int
vl_type
;
class
VisitedList
{
public:
vl_type
curV
;
vl_type
*
mass
;
unsigned
int
numelements
;
VisitedList
(
int
numelements1
)
{
curV
=
-
1
;
numelements
=
numelements1
;
mass
=
new
vl_type
[
numelements
];
}
void
reset
()
{
curV
++
;
if
(
curV
==
0
)
{
memset
(
mass
,
0
,
sizeof
(
vl_type
)
*
numelements
);
curV
++
;
}
};
~
VisitedList
()
{
delete
[]
mass
;
}
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////
class
VisitedListPool
{
std
::
deque
<
VisitedList
*>
pool
;
std
::
mutex
poolguard
;
int
numelements
;
public:
VisitedListPool
(
int
initmaxpools
,
int
numelements1
)
{
numelements
=
numelements1
;
for
(
int
i
=
0
;
i
<
initmaxpools
;
i
++
)
pool
.
push_front
(
new
VisitedList
(
numelements
));
}
VisitedList
*
getFreeVisitedList
()
{
VisitedList
*
rez
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
poolguard
);
if
(
pool
.
size
()
>
0
)
{
rez
=
pool
.
front
();
pool
.
pop_front
();
}
else
{
rez
=
new
VisitedList
(
numelements
);
}
}
rez
->
reset
();
return
rez
;
};
void
releaseVisitedList
(
VisitedList
*
vl
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
poolguard
);
pool
.
push_front
(
vl
);
};
~
VisitedListPool
()
{
while
(
pool
.
size
())
{
VisitedList
*
rez
=
pool
.
front
();
pool
.
pop_front
();
delete
rez
;
}
};
};
}
core/src/wrapper/ConfAdapter.cpp
浏览文件 @
829cb4bb
...
...
@@ -16,15 +16,16 @@
// under the License.
#include "wrapper/ConfAdapter.h"
#include "WrapperException.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "server/Config.h"
#include "utils/Log.h"
#include <cmath>
#include <memory>
#include <vector>
#include "WrapperException.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "server/Config.h"
#include "utils/Log.h"
// TODO(lxj): add conf checker
namespace
milvus
{
...
...
@@ -266,6 +267,17 @@ SPTAGBKTConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType&
return
conf
;
}
knowhere
::
Config
HNSWConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
HNSWCfg
>
();
conf
->
d
=
metaconf
.
dim
;
conf
->
metric_type
=
metaconf
.
metric_type
;
conf
->
ef
=
100
;
// ef can be auto-configured by using sample data.
conf
->
M
=
16
;
// A reasonable range of M is from 5 to 48.
return
conf
;
}
knowhere
::
Config
BinIDMAPConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
BinIDMAPCfg
>
();
...
...
core/src/wrapper/ConfAdapter.h
浏览文件 @
829cb4bb
...
...
@@ -17,11 +17,11 @@
#pragma once
#include <memory>
#include "VecIndex.h"
#include "knowhere/common/Config.h"
#include <memory>
namespace
milvus
{
namespace
engine
{
...
...
@@ -124,5 +124,11 @@ class BinIVFConfAdapter : public IVFConfAdapter {
Match
(
const
TempMetaConf
&
metaconf
)
override
;
};
class
HNSWConfAdapter
:
public
ConfAdapter
{
public:
knowhere
::
Config
Match
(
const
TempMetaConf
&
metaconf
)
override
;
};
}
// namespace engine
}
// namespace milvus
core/src/wrapper/ConfAdapterMgr.cpp
浏览文件 @
829cb4bb
...
...
@@ -16,6 +16,7 @@
// under the License.
#include "wrapper/ConfAdapterMgr.h"
#include "utils/Exception.h"
namespace
milvus
{
...
...
@@ -61,6 +62,8 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER
(
SPTAGKDTConfAdapter
,
IndexType
::
SPTAG_KDT_RNT_CPU
,
sptag_kdt
);
REGISTER_CONF_ADAPTER
(
SPTAGBKTConfAdapter
,
IndexType
::
SPTAG_BKT_RNT_CPU
,
sptag_bkt
);
REGISTER_CONF_ADAPTER
(
HNSWConfAdapter
,
IndexType
::
HNSW
,
hnsw
);
}
}
// namespace engine
...
...
core/src/wrapper/VecIndex.cpp
浏览文件 @
829cb4bb
...
...
@@ -16,10 +16,12 @@
// under the License.
#include "wrapper/VecIndex.h"
#include "VecImpl.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexBinaryIDMAP.h"
#include "knowhere/index/vector_index/IndexBinaryIVF.h"
#include "knowhere/index/vector_index/IndexHNSW.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
...
...
@@ -38,6 +40,7 @@
#ifdef MILVUS_GPU_VERSION
#include <cuda.h>
#include "knowhere/index/vector_index/IndexGPUIDMAP.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
...
...
@@ -99,6 +102,10 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
index
=
std
::
make_shared
<
knowhere
::
IVFSQ
>
();
break
;
}
case
IndexType
::
HNSW
:
{
index
=
std
::
make_shared
<
knowhere
::
IndexHNSW
>
();
break
;
}
#ifdef MILVUS_GPU_VERSION
case
IndexType
::
FAISS_IVFFLAT_GPU
:
{
...
...
core/src/wrapper/VecIndex.h
浏览文件 @
829cb4bb
...
...
@@ -50,6 +50,7 @@ enum class IndexType {
NSG_MIX
,
FAISS_IVFPQ_MIX
,
SPTAG_BKT_RNT_CPU
,
HNSW
,
FAISS_BIN_IDMAP
=
100
,
FAISS_BIN_IVFLAT_CPU
=
101
,
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录