Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
c6ce5772
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c6ce5772
编写于
7月 10, 2019
作者:
X
xj.lin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bad alloc and add idmap
Former-commit-id: bd2686574ad9010e33dcf34f3ae45308d5b3c971
上级
4fe9622b
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
323 addition
and
307 deletion
+323
-307
cpp/src/db/ExecutionEngineImpl.cpp
cpp/src/db/ExecutionEngineImpl.cpp
+27
-130
cpp/src/wrapper/knowhere/vec_impl.cpp
cpp/src/wrapper/knowhere/vec_impl.cpp
+37
-0
cpp/src/wrapper/knowhere/vec_impl.h
cpp/src/wrapper/knowhere/vec_impl.h
+22
-5
cpp/src/wrapper/knowhere/vec_index.cpp
cpp/src/wrapper/knowhere/vec_index.cpp
+117
-6
cpp/src/wrapper/knowhere/vec_index.h
cpp/src/wrapper/knowhere/vec_index.h
+17
-10
cpp/thirdparty/knowhere
cpp/thirdparty/knowhere
+1
-1
cpp/unittest/index_wrapper/knowhere_test.cpp
cpp/unittest/index_wrapper/knowhere_test.cpp
+88
-100
cpp/unittest/index_wrapper/utils.cpp
cpp/unittest/index_wrapper/utils.cpp
+6
-31
cpp/unittest/index_wrapper/utils.h
cpp/unittest/index_wrapper/utils.h
+8
-24
未找到文件。
cpp/src/db/ExecutionEngineImpl.cpp
浏览文件 @
c6ce5772
...
...
@@ -4,6 +4,7 @@
* Proprietary and confidential.
******************************************************************************/
#include <src/server/ServerConfig.h>
#include <src/metrics/Metrics.h>
#include "Log.h"
#include "src/cache/CpuCacheMgr.h"
...
...
@@ -16,55 +17,6 @@ namespace zilliz {
namespace
milvus
{
namespace
engine
{
struct
FileIOWriter
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOWriter
(
const
std
::
string
&
fname
);
~
FileIOWriter
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
struct
FileIOReader
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOReader
(
const
std
::
string
&
fname
);
~
FileIOReader
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
size_t
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
);
};
FileIOReader
::
FileIOReader
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
}
FileIOReader
::~
FileIOReader
()
{
fs
.
close
();
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
read
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
)
{
return
0
;
}
FileIOWriter
::
FileIOWriter
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
}
FileIOWriter
::~
FileIOWriter
()
{
fs
.
close
();
}
size_t
FileIOWriter
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
write
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
ExecutionEngineImpl
::
ExecutionEngineImpl
(
uint16_t
dimension
,
const
std
::
string
&
location
,
EngineType
type
)
...
...
@@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
break
;
}
case
EngineType
::
FAISS_IVFFLAT_GPU
:
{
index
=
GetVecIndexFactory
(
IndexType
::
FAISS_IVFFLAT_
GPU
);
index
=
GetVecIndexFactory
(
IndexType
::
FAISS_IVFFLAT_
MIX
);
break
;
}
case
EngineType
::
FAISS_IVFFLAT_CPU
:
{
...
...
@@ -130,89 +82,32 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}
Status
ExecutionEngineImpl
::
Serialize
()
{
auto
binaryset
=
index_
->
Serialize
();
FileIOWriter
writer
(
location_
);
writer
(
&
current_type
,
sizeof
(
current_type
));
for
(
auto
&
iter
:
binaryset
.
binary_map_
)
{
auto
meta
=
iter
.
first
.
c_str
();
size_t
meta_length
=
iter
.
first
.
length
();
writer
(
&
meta_length
,
sizeof
(
meta_length
));
writer
((
void
*
)
meta
,
meta_length
);
auto
binary
=
iter
.
second
;
size_t
binary_length
=
binary
->
size
;
writer
(
&
binary_length
,
sizeof
(
binary_length
));
writer
((
void
*
)
binary
->
data
.
get
(),
binary_length
);
}
write_index
(
index_
,
location_
);
return
Status
::
OK
();
}
Status
ExecutionEngineImpl
::
Load
()
{
index_
=
Load
(
location_
);
return
Status
::
OK
();
}
VecIndexPtr
ExecutionEngineImpl
::
Load
(
const
std
::
string
&
location
)
{
knowhere
::
BinarySet
load_data_list
;
FileIOReader
reader
(
location
);
reader
.
fs
.
seekg
(
0
,
reader
.
fs
.
end
);
size_t
length
=
reader
.
fs
.
tellg
();
reader
.
fs
.
seekg
(
0
);
size_t
rp
=
0
;
reader
(
&
current_type
,
sizeof
(
current_type
));
rp
+=
sizeof
(
current_type
);
while
(
rp
<
length
)
{
size_t
meta_length
;
reader
(
&
meta_length
,
sizeof
(
meta_length
));
rp
+=
sizeof
(
meta_length
);
reader
.
fs
.
seekg
(
rp
);
auto
meta
=
new
char
[
meta_length
];
reader
(
meta
,
meta_length
);
rp
+=
meta_length
;
reader
.
fs
.
seekg
(
rp
);
size_t
bin_length
;
reader
(
&
bin_length
,
sizeof
(
bin_length
));
rp
+=
sizeof
(
bin_length
);
reader
.
fs
.
seekg
(
rp
);
index_
=
zilliz
::
milvus
::
cache
::
CpuCacheMgr
::
GetInstance
()
->
GetIndex
(
location_
);
bool
to_cache
=
false
;
auto
start_time
=
METRICS_NOW_TIME
;
if
(
!
index_
)
{
index_
=
read_index
(
location_
);
to_cache
=
true
;
ENGINE_LOG_DEBUG
<<
"Disk io from: "
<<
location_
;
}
auto
bin
=
new
uint8_t
[
bin_length
];
reader
(
bin
,
bin_length
);
rp
+=
bin_length
;
if
(
to_cache
)
{
Cache
();
auto
end_time
=
METRICS_NOW_TIME
;
auto
total_time
=
METRICS_MICROSECONDS
(
start_time
,
end_time
);
auto
binptr
=
std
::
make_shared
<
uint8_t
>
();
binptr
.
reset
(
bin
);
load_data_list
.
Append
(
std
::
string
(
meta
,
meta_length
),
binptr
,
bin_length
);
}
server
::
Metrics
::
GetInstance
().
FaissDiskLoadDurationSecondsHistogramObserve
(
total_time
);
double
total_size
=
Size
();
auto
index_type
=
IndexType
::
INVALID
;
switch
(
current_type
)
{
case
EngineType
::
FAISS_IDMAP
:
{
index_type
=
IndexType
::
FAISS_IDMAP
;
break
;
}
case
EngineType
::
FAISS_IVFFLAT_CPU
:
{
index_type
=
IndexType
::
FAISS_IVFFLAT_CPU
;
break
;
}
case
EngineType
::
FAISS_IVFFLAT_GPU
:
{
index_type
=
IndexType
::
FAISS_IVFFLAT_GPU
;
break
;
}
case
EngineType
::
SPTAG_KDT_RNT_CPU
:
{
index_type
=
IndexType
::
SPTAG_KDT_RNT_CPU
;
break
;
}
default:
{
ENGINE_LOG_ERROR
<<
"wrong index_type"
;
return
nullptr
;
}
server
::
Metrics
::
GetInstance
().
FaissDiskLoadSizeBytesHistogramObserve
(
total_size
);
server
::
Metrics
::
GetInstance
().
FaissDiskLoadIOSpeedGaugeSet
(
total_size
/
double
(
total_time
));
}
return
LoadVecIndex
(
index_type
,
load_data_list
);
return
Status
::
OK
();
}
Status
ExecutionEngineImpl
::
Merge
(
const
std
::
string
&
location
)
{
...
...
@@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
auto
to_merge
=
zilliz
::
milvus
::
cache
::
CpuCacheMgr
::
GetInstance
()
->
GetIndex
(
location
);
if
(
!
to_merge
)
{
to_merge
=
Load
(
location
);
to_merge
=
read_index
(
location
);
}
auto
file_index
=
std
::
dynamic_pointer_cast
<
BFIndex
>
(
to_merge
);
index_
->
Add
(
file_index
->
Count
(),
file_index
->
GetRawVectors
(),
file_index
->
GetRawIds
());
return
Status
::
OK
();
if
(
auto
file_index
=
std
::
dynamic_pointer_cast
<
BFIndex
>
(
to_merge
))
{
index_
->
Add
(
file_index
->
Count
(),
file_index
->
GetRawVectors
(),
file_index
->
GetRawIds
());
return
Status
::
OK
();
}
else
{
return
Status
::
Error
(
"file index type is not idmap"
);
}
}
// TODO(linxj): add config
ExecutionEnginePtr
ExecutionEngineImpl
::
BuildIndex
(
const
std
::
string
&
location
)
{
ENGINE_LOG_DEBUG
<<
"Build index file: "
<<
location
<<
" from: "
<<
location_
;
...
...
cpp/src/wrapper/knowhere/vec_impl.cpp
浏览文件 @
c6ce5772
...
...
@@ -6,6 +6,7 @@
#include <src/utils/Log.h>
#include "knowhere/index/vector_index/idmap.h"
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "vec_impl.h"
#include "data_transfer.h"
...
...
@@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() {
return
index_
->
Count
();
}
IndexType
VecIndexImpl
::
GetType
()
{
return
type
;
}
float
*
BFIndex
::
GetRawVectors
()
{
auto
raw_index
=
std
::
dynamic_pointer_cast
<
IDMAP
>
(
index_
);
if
(
raw_index
)
{
return
raw_index
->
GetRawVectors
();
}
...
...
@@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb,
index_
->
Add
(
dataset
,
cfg
);
}
// TODO(linxj): add lock here.
void
IVFMixIndex
::
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
,
const
long
&
nt
,
const
float
*
xt
)
{
dim
=
cfg
[
"dim"
].
as
<
int
>
();
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
auto
preprocessor
=
index_
->
BuildPreprocessor
(
dataset
,
cfg
);
index_
->
set_preprocessor
(
preprocessor
);
auto
nlist
=
int
(
nb
/
1000000.0
*
16384
);
auto
cfg_t
=
Config
::
object
{{
"nlist"
,
nlist
},
{
"dim"
,
dim
}};
auto
model
=
index_
->
Train
(
dataset
,
cfg_t
);
index_
->
set_index_model
(
model
);
index_
->
Add
(
dataset
,
cfg
);
if
(
auto
device_index
=
std
::
dynamic_pointer_cast
<
GPUIVF
>
(
index_
))
{
auto
host_index
=
device_index
->
Copy_index_gpu_to_cpu
();
index_
=
host_index
;
}
else
{
// TODO(linxj): LOG ERROR
}
}
void
IVFMixIndex
::
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
index_
=
std
::
make_shared
<
IVF
>
();
index_
->
Load
(
index_binary
);
dim
=
Dimension
();
}
}
}
}
cpp/src/wrapper/knowhere/vec_impl.h
浏览文件 @
c6ce5772
...
...
@@ -17,13 +17,15 @@ namespace engine {
class
VecIndexImpl
:
public
VecIndex
{
public:
explicit
VecIndexImpl
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
index_
(
std
::
move
(
index
))
{};
explicit
VecIndexImpl
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
,
const
IndexType
&
type
)
:
index_
(
std
::
move
(
index
)),
type
(
type
)
{};
void
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
,
const
long
&
nt
,
const
float
*
xt
)
override
;
IndexType
GetType
()
override
;
int64_t
Dimension
()
override
;
int64_t
Count
()
override
;
void
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
)
override
;
...
...
@@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex {
protected:
int64_t
dim
=
0
;
IndexType
type
=
IndexType
::
INVALID
;
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index_
=
nullptr
;
};
class
IVFMixIndex
:
public
VecIndexImpl
{
public:
explicit
IVFMixIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
VecIndexImpl
(
std
::
move
(
index
),
IndexType
::
FAISS_IVFFLAT_MIX
)
{};
void
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
,
const
long
&
nt
,
const
float
*
xt
)
override
;
void
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
override
;
};
class
BFIndex
:
public
VecIndexImpl
{
public:
explicit
BFIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
VecIndexImpl
(
std
::
move
(
index
))
{};
void
Build
(
const
int64_t
&
d
);
float
*
GetRawVectors
();
explicit
BFIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
VecIndexImpl
(
std
::
move
(
index
),
IndexType
::
FAISS_IDMAP
)
{};
void
Build
(
const
int64_t
&
d
);
float
*
GetRawVectors
();
void
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
,
const
long
&
nt
,
const
float
*
xt
)
override
;
int64_t
*
GetRawIds
();
int64_t
*
GetRawIds
();
};
}
...
...
cpp/src/wrapper/knowhere/vec_index.cpp
浏览文件 @
c6ce5772
...
...
@@ -16,7 +16,56 @@ namespace zilliz {
namespace
milvus
{
namespace
engine
{
// TODO(linxj): index_type => enum struct
struct
FileIOWriter
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOWriter
(
const
std
::
string
&
fname
);
~
FileIOWriter
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
struct
FileIOReader
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOReader
(
const
std
::
string
&
fname
);
~
FileIOReader
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
size_t
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
);
};
FileIOReader
::
FileIOReader
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
}
FileIOReader
::~
FileIOReader
()
{
fs
.
close
();
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
read
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
)
{
return
0
;
}
FileIOWriter
::
FileIOWriter
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
}
FileIOWriter
::~
FileIOWriter
()
{
fs
.
close
();
}
size_t
FileIOWriter
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
write
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
)
{
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
;
switch
(
type
)
{
...
...
@@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index
=
std
::
make_shared
<
zilliz
::
knowhere
::
GPUIVF
>
(
0
);
break
;
}
case
IndexType
::
FAISS_IVFFLAT_MIX
:
{
index
=
std
::
make_shared
<
zilliz
::
knowhere
::
GPUIVF
>
(
0
);
return
std
::
make_shared
<
IVFMixIndex
>
(
index
);
}
case
IndexType
::
FAISS_IVFPQ_CPU
:
{
index
=
std
::
make_shared
<
zilliz
::
knowhere
::
IVFPQ
>
();
break
;
...
...
@@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
index
=
std
::
make_shared
<
zilliz
::
knowhere
::
CPUKDTRNG
>
();
break
;
}
//case IndexType::NSG: { // TODO(linxj): bug.
// index = std::make_shared<zilliz::knowhere::NSG>();
// break;
//}
//case IndexType::NSG: { // TODO(linxj): bug.
// index = std::make_shared<zilliz::knowhere::NSG>();
// break;
//}
default:
{
return
nullptr
;
}
}
return
std
::
make_shared
<
VecIndexImpl
>
(
index
);
return
std
::
make_shared
<
VecIndexImpl
>
(
index
,
type
);
}
VecIndexPtr
LoadVecIndex
(
const
IndexType
&
index_type
,
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
...
...
@@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi
return
index
;
}
VecIndexPtr
read_index
(
const
std
::
string
&
location
)
{
knowhere
::
BinarySet
load_data_list
;
FileIOReader
reader
(
location
);
reader
.
fs
.
seekg
(
0
,
reader
.
fs
.
end
);
size_t
length
=
reader
.
fs
.
tellg
();
reader
.
fs
.
seekg
(
0
);
size_t
rp
=
0
;
auto
current_type
=
IndexType
::
INVALID
;
reader
(
&
current_type
,
sizeof
(
current_type
));
rp
+=
sizeof
(
current_type
);
while
(
rp
<
length
)
{
size_t
meta_length
;
reader
(
&
meta_length
,
sizeof
(
meta_length
));
rp
+=
sizeof
(
meta_length
);
reader
.
fs
.
seekg
(
rp
);
auto
meta
=
new
char
[
meta_length
];
reader
(
meta
,
meta_length
);
rp
+=
meta_length
;
reader
.
fs
.
seekg
(
rp
);
size_t
bin_length
;
reader
(
&
bin_length
,
sizeof
(
bin_length
));
rp
+=
sizeof
(
bin_length
);
reader
.
fs
.
seekg
(
rp
);
auto
bin
=
new
uint8_t
[
bin_length
];
reader
(
bin
,
bin_length
);
rp
+=
bin_length
;
auto
binptr
=
std
::
make_shared
<
uint8_t
>
();
binptr
.
reset
(
bin
);
load_data_list
.
Append
(
std
::
string
(
meta
,
meta_length
),
binptr
,
bin_length
);
}
return
LoadVecIndex
(
current_type
,
load_data_list
);
}
void
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
)
{
auto
binaryset
=
index
->
Serialize
();
auto
index_type
=
index
->
GetType
();
FileIOWriter
writer
(
location
);
writer
(
&
index_type
,
sizeof
(
IndexType
));
for
(
auto
&
iter
:
binaryset
.
binary_map_
)
{
auto
meta
=
iter
.
first
.
c_str
();
size_t
meta_length
=
iter
.
first
.
length
();
writer
(
&
meta_length
,
sizeof
(
meta_length
));
writer
((
void
*
)
meta
,
meta_length
);
auto
binary
=
iter
.
second
;
int64_t
binary_length
=
binary
->
size
;
writer
(
&
binary_length
,
sizeof
(
binary_length
));
writer
((
void
*
)
binary
->
data
.
get
(),
binary_length
);
}
}
}
}
}
cpp/src/wrapper/knowhere/vec_index.h
浏览文件 @
c6ce5772
...
...
@@ -20,6 +20,18 @@ namespace engine {
// TODO(linxj): jsoncons => rapidjson or other.
using
Config
=
zilliz
::
knowhere
::
Config
;
enum
class
IndexType
{
INVALID
=
0
,
FAISS_IDMAP
=
1
,
FAISS_IVFFLAT_CPU
,
FAISS_IVFFLAT_GPU
,
FAISS_IVFFLAT_MIX
,
// build on gpu and search on cpu
FAISS_IVFPQ_CPU
,
FAISS_IVFPQ_GPU
,
SPTAG_KDT_RNT_CPU
,
//NSG,
};
class
VecIndex
{
public:
virtual
void
BuildAll
(
const
long
&
nb
,
...
...
@@ -40,6 +52,8 @@ class VecIndex {
long
*
ids
,
const
Config
&
cfg
=
Config
())
=
0
;
virtual
IndexType
GetType
()
=
0
;
virtual
int64_t
Dimension
()
=
0
;
virtual
int64_t
Count
()
=
0
;
...
...
@@ -51,16 +65,9 @@ class VecIndex {
using
VecIndexPtr
=
std
::
shared_ptr
<
VecIndex
>
;
enum
class
IndexType
{
INVALID
=
0
,
FAISS_IDMAP
=
1
,
FAISS_IVFFLAT_CPU
,
FAISS_IVFFLAT_GPU
,
FAISS_IVFPQ_CPU
,
FAISS_IVFPQ_GPU
,
SPTAG_KDT_RNT_CPU
,
//NSG,
};
extern
void
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
);
extern
VecIndexPtr
read_index
(
const
std
::
string
&
location
);
extern
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
);
...
...
knowhere
@
ca99a689
Subproject commit c
3123501d62f69f9eacaa73ee96c0daeb24620a5
Subproject commit c
a99a6899be4e8a0806452656cf0f2be19d79c1a
cpp/unittest/index_wrapper/knowhere_test.cpp
浏览文件 @
c6ce5772
...
...
@@ -28,11 +28,37 @@ class KnowhereWrapperTest
//auto generator = GetGenerateFactory(generator_type);
auto
generator
=
std
::
make_shared
<
DataGenBase
>
();
generator
->
GenData
(
dim
,
nb
,
nq
,
xb
,
xq
,
ids
,
k
,
gt_ids
);
generator
->
GenData
(
dim
,
nb
,
nq
,
xb
,
xq
,
ids
,
k
,
gt_ids
,
gt_dis
);
index_
=
GetVecIndexFactory
(
index_type
);
}
void
AssertResult
(
const
std
::
vector
<
long
>
&
ids
,
const
std
::
vector
<
float
>
&
dis
)
{
EXPECT_EQ
(
ids
.
size
(),
nq
*
k
);
EXPECT_EQ
(
dis
.
size
(),
nq
*
k
);
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
EXPECT_EQ
(
ids
[
i
*
k
],
gt_ids
[
i
*
k
]);
EXPECT_EQ
(
dis
[
i
*
k
],
gt_dis
[
i
*
k
]);
}
int
match
=
0
;
for
(
int
i
=
0
;
i
<
nq
;
++
i
)
{
for
(
int
j
=
0
;
j
<
k
;
++
j
)
{
for
(
int
l
=
0
;
l
<
k
;
++
l
)
{
if
(
ids
[
i
*
nq
+
j
]
==
gt_ids
[
i
*
nq
+
l
])
match
++
;
}
}
}
auto
precision
=
float
(
match
)
/
(
nq
*
k
);
EXPECT_GT
(
precision
,
0.5
);
std
::
cout
<<
std
::
endl
<<
"Precision: "
<<
precision
<<
", match: "
<<
match
<<
", total: "
<<
nq
*
k
<<
std
::
endl
;
}
protected:
IndexType
index_type
;
Config
train_cfg
;
...
...
@@ -50,126 +76,88 @@ class KnowhereWrapperTest
// Ground Truth
std
::
vector
<
long
>
gt_ids
;
std
::
vector
<
float
>
gt_dis
;
};
INSTANTIATE_TEST_CASE_P
(
WrapperParam
,
KnowhereWrapperTest
,
Values
(
//
["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
std
::
make_tuple
(
IndexType
::
FAISS_IVFFLAT_CPU
,
"Default"
,
64
,
10000
,
10
,
10
,
64
,
10000
0
,
10
,
10
,
Config
::
object
{{
"nlist"
,
100
},
{
"dim"
,
64
}},
Config
::
object
{{
"dim"
,
64
},
{
"k"
,
10
},
{
"nprobe"
,
2
0
}}
Config
::
object
{{
"dim"
,
64
},
{
"k"
,
10
},
{
"nprobe"
,
1
0
}}
),
std
::
make_tuple
(
IndexType
::
SPTAG_KDT_RNT_CPU
,
"Default"
,
64
,
10000
,
10
,
10
,
Config
::
object
{{
"TPTNumber"
,
1
},
{
"dim"
,
64
}},
//std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"nlist", 100}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}}
//),
std
::
make_tuple
(
IndexType
::
FAISS_IVFFLAT_MIX
,
"Default"
,
64
,
100000
,
10
,
10
,
Config
::
object
{{
"nlist"
,
100
},
{
"dim"
,
64
}},
Config
::
object
{{
"dim"
,
64
},
{
"k"
,
10
},
{
"nprobe"
,
10
}}
),
std
::
make_tuple
(
IndexType
::
FAISS_IDMAP
,
"Default"
,
64
,
100000
,
10
,
10
,
Config
::
object
{{
"dim"
,
64
}},
Config
::
object
{{
"dim"
,
64
},
{
"k"
,
10
}}
)
//std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
// 64, 10000, 10, 10,
// Config::object{{"TPTNumber", 1}, {"dim", 64}},
// Config::object{{"dim", 64}, {"k", 10}}
//)
)
);
void
AssertAnns
(
const
std
::
vector
<
long
>
&
gt
,
const
std
::
vector
<
long
>
&
res
,
const
int
&
nq
,
const
int
&
k
)
{
EXPECT_EQ
(
res
.
size
(),
nq
*
k
);
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
EXPECT_EQ
(
gt
[
i
*
k
],
res
[
i
*
k
]);
}
int
match
=
0
;
for
(
int
i
=
0
;
i
<
nq
;
++
i
)
{
for
(
int
j
=
0
;
j
<
k
;
++
j
)
{
for
(
int
l
=
0
;
l
<
k
;
++
l
)
{
if
(
gt
[
i
*
nq
+
j
]
==
res
[
i
*
nq
+
l
])
match
++
;
}
}
}
// TODO(linxj): percision check
EXPECT_GT
(
float
(
match
/
nq
*
k
),
0.5
);
}
TEST_P
(
KnowhereWrapperTest
,
base_test
)
{
std
::
vector
<
long
>
res_ids
;
float
*
D
=
new
float
[
k
*
nq
];
res_ids
.
resize
(
nq
*
k
);
EXPECT_EQ
(
index_
->
GetType
(),
index_type
);
auto
elems
=
nq
*
k
;
std
::
vector
<
int64_t
>
res_ids
(
elems
);
std
::
vector
<
float
>
res_dis
(
elems
);
index_
->
BuildAll
(
nb
,
xb
.
data
(),
ids
.
data
(),
train_cfg
);
index_
->
Search
(
nq
,
xq
.
data
(),
D
,
res_ids
.
data
(),
search_cfg
);
AssertAnns
(
gt_ids
,
res_ids
,
nq
,
k
);
delete
[]
D
;
index_
->
Search
(
nq
,
xq
.
data
(),
res_dis
.
data
(),
res_ids
.
data
(),
search_cfg
);
AssertResult
(
res_ids
,
res_dis
);
}
TEST_P
(
KnowhereWrapperTest
,
serialize_test
)
{
std
::
vector
<
long
>
res_ids
;
float
*
D
=
new
float
[
k
*
nq
];
res_ids
.
resize
(
nq
*
k
);
TEST_P
(
KnowhereWrapperTest
,
serialize
)
{
EXPECT_EQ
(
index_
->
GetType
(),
index_type
);
auto
elems
=
nq
*
k
;
std
::
vector
<
int64_t
>
res_ids
(
elems
);
std
::
vector
<
float
>
res_dis
(
elems
);
index_
->
BuildAll
(
nb
,
xb
.
data
(),
ids
.
data
(),
train_cfg
);
index_
->
Search
(
nq
,
xq
.
data
(),
D
,
res_ids
.
data
(),
search_cfg
);
Assert
Anns
(
gt_ids
,
res_ids
,
nq
,
k
);
index_
->
Search
(
nq
,
xq
.
data
(),
res_dis
.
data
()
,
res_ids
.
data
(),
search_cfg
);
Assert
Result
(
res_ids
,
res_dis
);
{
auto
binaryset
=
index_
->
Serialize
();
//int fileno = 0;
//const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
//std::vector<std::string> filename_list;
//std::vector<std::pair<std::string, size_t >> meta_list;
//for (auto &iter: binaryset.binary_map_) {
// const std::string &filename = base_name + std::to_string(fileno);
// FileIOWriter writer(filename);
// writer(iter.second->data.get(), iter.second->size);
//
// meta_list.push_back(std::make_pair(iter.first, iter.second.size));
// filename_list.push_back(filename);
// ++fileno;
//}
//
//BinarySet load_data_list;
//for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
// auto bin_size = meta_list[i].second;
// FileIOReader reader(filename_list[i]);
// std::vector<uint8_t> load_data(bin_size);
// reader(load_data.data(), bin_size);
// load_data_list.Append(meta_list[i].first, load_data);
//}
int
fileno
=
0
;
std
::
vector
<
std
::
string
>
filename_list
;
const
std
::
string
&
base_name
=
"/tmp/wrapper_serialize_test_bin_"
;
std
::
vector
<
std
::
pair
<
std
::
string
,
size_t
>>
meta_list
;
for
(
auto
&
iter
:
binaryset
.
binary_map_
)
{
const
std
::
string
&
filename
=
base_name
+
std
::
to_string
(
fileno
);
FileIOWriter
writer
(
filename
);
writer
(
iter
.
second
->
data
.
get
(),
iter
.
second
->
size
);
meta_list
.
emplace_back
(
std
::
make_pair
(
iter
.
first
,
iter
.
second
->
size
));
filename_list
.
push_back
(
filename
);
++
fileno
;
}
BinarySet
load_data_list
;
for
(
int
i
=
0
;
i
<
filename_list
.
size
()
&&
i
<
meta_list
.
size
();
++
i
)
{
auto
bin_size
=
meta_list
[
i
].
second
;
FileIOReader
reader
(
filename_list
[
i
]);
auto
load_data
=
new
uint8_t
[
bin_size
];
reader
(
load_data
,
bin_size
);
auto
data
=
std
::
make_shared
<
uint8_t
>
();
data
.
reset
(
load_data
);
load_data_list
.
Append
(
meta_list
[
i
].
first
,
data
,
bin_size
);
}
res_ids
.
clear
();
res_ids
.
resize
(
nq
*
k
);
auto
new_index
=
GetVecIndexFactory
(
index_type
);
new_index
->
Load
(
load_data_list
);
new_index
->
Search
(
nq
,
xq
.
data
(),
D
,
res_ids
.
data
(),
search_cfg
);
AssertAnns
(
gt_ids
,
res_ids
,
nq
,
k
);
auto
binary
=
index_
->
Serialize
();
auto
type
=
index_
->
GetType
();
auto
new_index
=
GetVecIndexFactory
(
type
);
new_index
->
Load
(
binary
);
EXPECT_EQ
(
new_index
->
Dimension
(),
index_
->
Dimension
());
EXPECT_EQ
(
new_index
->
Count
(),
index_
->
Count
());
std
::
vector
<
int64_t
>
res_ids
(
elems
);
std
::
vector
<
float
>
res_dis
(
elems
);
new_index
->
Search
(
nq
,
xq
.
data
(),
res_dis
.
data
(),
res_ids
.
data
(),
search_cfg
);
AssertResult
(
res_ids
,
res_dis
);
}
delete
[]
D
;
{
std
::
string
file_location
=
"/tmp/whatever"
;
write_index
(
index_
,
file_location
);
auto
new_index
=
read_index
(
file_location
);
EXPECT_EQ
(
new_index
->
GetType
(),
index_type
);
EXPECT_EQ
(
new_index
->
Dimension
(),
index_
->
Dimension
());
EXPECT_EQ
(
new_index
->
Count
(),
index_
->
Count
());
std
::
vector
<
int64_t
>
res_ids
(
elems
);
std
::
vector
<
float
>
res_dis
(
elems
);
new_index
->
Search
(
nq
,
xq
.
data
(),
res_dis
.
data
(),
res_ids
.
data
(),
search_cfg
);
AssertResult
(
res_ids
,
res_dis
);
}
}
cpp/unittest/index_wrapper/utils.cpp
浏览文件 @
c6ce5772
...
...
@@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) {
void
DataGenBase
::
GenData
(
const
int
&
dim
,
const
int
&
nb
,
const
int
&
nq
,
float
*
xb
,
float
*
xq
,
long
*
ids
,
const
int
&
k
,
long
*
gt_ids
)
{
const
int
&
k
,
long
*
gt_ids
,
float
*
gt_dis
)
{
for
(
auto
i
=
0
;
i
<
nb
;
++
i
)
{
for
(
auto
j
=
0
;
j
<
dim
;
++
j
)
{
//p_data[i * d + j] = float(base + i);
...
...
@@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
faiss
::
IndexFlatL2
index
(
dim
);
//index.add_with_ids(nb, xb, ids);
index
.
add
(
nb
,
xb
);
float
*
D
=
new
float
[
k
*
nq
];
index
.
search
(
nq
,
xq
,
k
,
D
,
gt_ids
);
index
.
search
(
nq
,
xq
,
k
,
gt_dis
,
gt_ids
);
}
void
DataGenBase
::
GenData
(
const
int
&
dim
,
...
...
@@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim,
std
::
vector
<
float
>
&
xq
,
std
::
vector
<
long
>
&
ids
,
const
int
&
k
,
std
::
vector
<
long
>
&
gt_ids
)
{
std
::
vector
<
long
>
&
gt_ids
,
std
::
vector
<
float
>
&
gt_dis
)
{
xb
.
resize
(
nb
*
dim
);
xq
.
resize
(
nq
*
dim
);
ids
.
resize
(
nb
);
gt_ids
.
resize
(
nq
*
k
);
GenData
(
dim
,
nb
,
nq
,
xb
.
data
(),
xq
.
data
(),
ids
.
data
(),
k
,
gt_ids
.
data
());
}
FileIOReader
::
FileIOReader
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
}
FileIOReader
::~
FileIOReader
()
{
fs
.
close
();
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
read
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
FileIOWriter
::
FileIOWriter
(
const
std
::
string
&
fname
)
{
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
}
FileIOWriter
::~
FileIOWriter
()
{
fs
.
close
();
}
size_t
FileIOWriter
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
write
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
gt_dis
.
resize
(
nq
*
k
);
GenData
(
dim
,
nb
,
nq
,
xb
.
data
(),
xq
.
data
(),
ids
.
data
(),
k
,
gt_ids
.
data
(),
gt_dis
.
data
());
}
cpp/unittest/index_wrapper/utils.h
浏览文件 @
c6ce5772
...
...
@@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type);
class
DataGenBase
{
public:
virtual
void
GenData
(
const
int
&
dim
,
const
int
&
nb
,
const
int
&
nq
,
float
*
xb
,
float
*
xq
,
long
*
ids
,
const
int
&
k
,
long
*
gt_ids
);
const
int
&
k
,
long
*
gt_ids
,
float
*
gt_dis
);
virtual
void
GenData
(
const
int
&
dim
,
const
int
&
nb
,
...
...
@@ -32,30 +32,14 @@ class DataGenBase {
std
::
vector
<
float
>
&
xq
,
std
::
vector
<
long
>
&
ids
,
const
int
&
k
,
std
::
vector
<
long
>
&
gt_ids
);
std
::
vector
<
long
>
&
gt_ids
,
std
::
vector
<
float
>
&
gt_dis
);
};
class
SanityCheck
:
public
DataGenBase
{
public:
void
GenData
(
const
int
&
dim
,
const
int
&
nb
,
const
int
&
nq
,
float
*
xb
,
float
*
xq
,
long
*
ids
,
const
int
&
k
,
long
*
gt_id
s
)
override
;
};
//
class SanityCheck : public DataGenBase {
//
public:
//
void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
// const int &k, long *gt_ids, float *gt_di
s) override;
//
};
struct
FileIOWriter
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOWriter
(
const
std
::
string
&
fname
);
~
FileIOWriter
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
struct
FileIOReader
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOReader
(
const
std
::
string
&
fname
);
~
FileIOReader
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录