Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
b2003e9f
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b2003e9f
编写于
9月 20, 2019
作者:
H
Heisenberg
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MS-583 Change to Status from errorcode
Former-commit-id: eaef6a33657b660e4e4c832a56b6982f74e49ec2
上级
c7ec7ced
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
295 addition
and
209 deletion
+295
-209
cpp/src/db/engine/ExecutionEngineImpl.cpp
cpp/src/db/engine/ExecutionEngineImpl.cpp
+12
-20
cpp/src/wrapper/KnowhereResource.cpp
cpp/src/wrapper/KnowhereResource.cpp
+6
-4
cpp/src/wrapper/KnowhereResource.h
cpp/src/wrapper/KnowhereResource.h
+6
-3
cpp/src/wrapper/vec_impl.cpp
cpp/src/wrapper/vec_impl.cpp
+75
-58
cpp/src/wrapper/vec_impl.h
cpp/src/wrapper/vec_impl.h
+70
-34
cpp/src/wrapper/vec_index.cpp
cpp/src/wrapper/vec_index.cpp
+45
-28
cpp/src/wrapper/vec_index.h
cpp/src/wrapper/vec_index.h
+67
-46
cpp/unittest/scheduler/scheduler_test.cpp
cpp/unittest/scheduler/scheduler_test.cpp
+5
-7
cpp/unittest/server/cache_test.cpp
cpp/unittest/server/cache_test.cpp
+8
-9
cpp/unittest/wrapper/CMakeLists.txt
cpp/unittest/wrapper/CMakeLists.txt
+1
-0
未找到文件。
cpp/src/db/engine/ExecutionEngineImpl.cpp
浏览文件 @
b2003e9f
...
@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
...
@@ -99,11 +99,8 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
}
}
Status
ExecutionEngineImpl
::
AddWithIds
(
long
n
,
const
float
*
xdata
,
const
long
*
xids
)
{
Status
ExecutionEngineImpl
::
AddWithIds
(
long
n
,
const
float
*
xdata
,
const
long
*
xids
)
{
auto
ec
=
index_
->
Add
(
n
,
xdata
,
xids
);
auto
status
=
index_
->
Add
(
n
,
xdata
,
xids
);
if
(
ec
!=
KNOWHERE_SUCCESS
)
{
return
status
;
return
Status
(
DB_ERROR
,
"Add error"
);
}
return
Status
::
OK
();
}
}
size_t
ExecutionEngineImpl
::
Count
()
const
{
size_t
ExecutionEngineImpl
::
Count
()
const
{
...
@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
...
@@ -131,11 +128,8 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
}
}
Status
ExecutionEngineImpl
::
Serialize
()
{
Status
ExecutionEngineImpl
::
Serialize
()
{
auto
ec
=
write_index
(
index_
,
location_
);
auto
status
=
write_index
(
index_
,
location_
);
if
(
ec
!=
KNOWHERE_SUCCESS
)
{
return
status
;
return
Status
(
DB_ERROR
,
"Serialize: write to disk error"
);
}
return
Status
::
OK
();
}
}
Status
ExecutionEngineImpl
::
Load
(
bool
to_cache
)
{
Status
ExecutionEngineImpl
::
Load
(
bool
to_cache
)
{
...
@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
...
@@ -254,12 +248,11 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
}
}
if
(
auto
file_index
=
std
::
dynamic_pointer_cast
<
BFIndex
>
(
to_merge
))
{
if
(
auto
file_index
=
std
::
dynamic_pointer_cast
<
BFIndex
>
(
to_merge
))
{
auto
ec
=
index_
->
Add
(
file_index
->
Count
(),
file_index
->
GetRawVectors
(),
file_index
->
GetRawIds
());
auto
status
=
index_
->
Add
(
file_index
->
Count
(),
file_index
->
GetRawVectors
(),
file_index
->
GetRawIds
());
if
(
ec
!=
KNOWHERE_SUCCESS
)
{
if
(
!
status
.
ok
()
)
{
ENGINE_LOG_ERROR
<<
"Merge: Add Error"
;
ENGINE_LOG_ERROR
<<
"Merge: Add Error"
;
return
Status
(
DB_ERROR
,
"Merge: Add Error"
);
}
}
return
Status
::
OK
()
;
return
status
;
}
else
{
}
else
{
return
Status
(
DB_ERROR
,
"file index type is not idmap"
);
return
Status
(
DB_ERROR
,
"file index type is not idmap"
);
}
}
...
@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
...
@@ -287,11 +280,11 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t
build_cfg
[
"nlist"
]
=
nlist_
;
build_cfg
[
"nlist"
]
=
nlist_
;
AutoGenParams
(
to_index
->
GetType
(),
Count
(),
build_cfg
);
AutoGenParams
(
to_index
->
GetType
(),
Count
(),
build_cfg
);
auto
ec
=
to_index
->
BuildAll
(
Count
(),
auto
status
=
to_index
->
BuildAll
(
Count
(),
from_index
->
GetRawVectors
(),
from_index
->
GetRawVectors
(),
from_index
->
GetRawIds
(),
from_index
->
GetRawIds
(),
build_cfg
);
build_cfg
);
if
(
ec
!=
KNOWHERE_SUCCESS
)
{
throw
Exception
(
DB_ERROR
,
"Build index error"
);
}
if
(
!
status
.
ok
())
{
throw
Exception
(
DB_ERROR
,
status
.
message
()
);
}
return
std
::
make_shared
<
ExecutionEngineImpl
>
(
to_index
,
location
,
engine_type
,
metric_type_
,
nlist_
);
return
std
::
make_shared
<
ExecutionEngineImpl
>
(
to_index
,
location
,
engine_type
,
metric_type_
,
nlist_
);
}
}
...
@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,
...
@@ -309,12 +302,11 @@ Status ExecutionEngineImpl::Search(long n,
ENGINE_LOG_DEBUG
<<
"Search Params: [k] "
<<
k
<<
" [nprobe] "
<<
nprobe
;
ENGINE_LOG_DEBUG
<<
"Search Params: [k] "
<<
k
<<
" [nprobe] "
<<
nprobe
;
auto
cfg
=
Config
::
object
{{
"k"
,
k
},
{
"nprobe"
,
nprobe
}};
auto
cfg
=
Config
::
object
{{
"k"
,
k
},
{
"nprobe"
,
nprobe
}};
auto
ec
=
index_
->
Search
(
n
,
data
,
distances
,
labels
,
cfg
);
auto
status
=
index_
->
Search
(
n
,
data
,
distances
,
labels
,
cfg
);
if
(
ec
!=
KNOWHERE_SUCCESS
)
{
if
(
!
status
.
ok
()
)
{
ENGINE_LOG_ERROR
<<
"Search error"
;
ENGINE_LOG_ERROR
<<
"Search error"
;
return
Status
(
DB_ERROR
,
"Search: Search Error"
);
}
}
return
Status
::
OK
()
;
return
status
;
}
}
Status
ExecutionEngineImpl
::
Cache
()
{
Status
ExecutionEngineImpl
::
Cache
()
{
...
...
cpp/src/wrapper/KnowhereResource.cpp
浏览文件 @
b2003e9f
...
@@ -28,7 +28,8 @@ namespace engine {
...
@@ -28,7 +28,8 @@ namespace engine {
constexpr
int64_t
M_BYTE
=
1024
*
1024
;
constexpr
int64_t
M_BYTE
=
1024
*
1024
;
ErrorCode
KnowhereResource
::
Initialize
()
{
Status
KnowhereResource
::
Initialize
()
{
struct
GpuResourceSetting
{
struct
GpuResourceSetting
{
int64_t
pinned_memory
=
300
*
M_BYTE
;
int64_t
pinned_memory
=
300
*
M_BYTE
;
int64_t
temp_memory
=
300
*
M_BYTE
;
int64_t
temp_memory
=
300
*
M_BYTE
;
...
@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
...
@@ -65,12 +66,13 @@ ErrorCode KnowhereResource::Initialize() {
iter
->
second
.
resource_num
);
iter
->
second
.
resource_num
);
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
ErrorCode
KnowhereResource
::
Finalize
()
{
Status
KnowhereResource
::
Finalize
()
{
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
// free gpu resource.
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
// free gpu resource.
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
}
}
...
...
cpp/src/wrapper/KnowhereResource.h
浏览文件 @
b2003e9f
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#pragma once
#pragma once
#include "utils/
Error
.h"
#include "utils/
Status
.h"
namespace
zilliz
{
namespace
zilliz
{
namespace
milvus
{
namespace
milvus
{
...
@@ -26,8 +26,11 @@ namespace engine {
...
@@ -26,8 +26,11 @@ namespace engine {
class
KnowhereResource
{
class
KnowhereResource
{
public:
public:
static
ErrorCode
Initialize
();
static
Status
static
ErrorCode
Finalize
();
Initialize
();
static
Status
Finalize
();
};
};
...
...
cpp/src/wrapper/vec_impl.cpp
浏览文件 @
b2003e9f
...
@@ -21,7 +21,6 @@
...
@@ -21,7 +21,6 @@
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "vec_impl.h"
#include "vec_impl.h"
#include "data_transfer.h"
#include "data_transfer.h"
...
@@ -32,12 +31,13 @@ namespace engine {
...
@@ -32,12 +31,13 @@ namespace engine {
using
namespace
zilliz
::
knowhere
;
using
namespace
zilliz
::
knowhere
;
ErrorCode
VecIndexImpl
::
BuildAll
(
const
long
&
nb
,
Status
const
float
*
xb
,
VecIndexImpl
::
BuildAll
(
const
long
&
nb
,
const
long
*
ids
,
const
float
*
xb
,
const
Config
&
cfg
,
const
long
*
ids
,
const
long
&
nt
,
const
Config
&
cfg
,
const
float
*
xt
)
{
const
long
&
nt
,
const
float
*
xt
)
{
try
{
try
{
dim
=
cfg
[
"dim"
].
as
<
int
>
();
dim
=
cfg
[
"dim"
].
as
<
int
>
();
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
...
@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
...
@@ -49,36 +49,38 @@ ErrorCode VecIndexImpl::BuildAll(const long &nb,
index_
->
Add
(
dataset
,
cfg
);
index_
->
Add
(
dataset
,
cfg
);
}
catch
(
KnowhereException
&
e
)
{
}
catch
(
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
jsoncons
::
json_exception
&
e
)
{
}
catch
(
jsoncons
::
json_exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_INVALID_ARGUMENT
;
return
Status
(
KNOWHERE_INVALID_ARGUMENT
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
ErrorCode
VecIndexImpl
::
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
)
{
Status
VecIndexImpl
::
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
)
{
try
{
try
{
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
index_
->
Add
(
dataset
,
cfg
);
index_
->
Add
(
dataset
,
cfg
);
}
catch
(
KnowhereException
&
e
)
{
}
catch
(
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
jsoncons
::
json_exception
&
e
)
{
}
catch
(
jsoncons
::
json_exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_INVALID_ARGUMENT
;
return
Status
(
KNOWHERE_INVALID_ARGUMENT
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
ErrorCode
VecIndexImpl
::
Search
(
const
long
&
nq
,
const
float
*
xq
,
float
*
dist
,
long
*
ids
,
const
Config
&
cfg
)
{
Status
VecIndexImpl
::
Search
(
const
long
&
nq
,
const
float
*
xq
,
float
*
dist
,
long
*
ids
,
const
Config
&
cfg
)
{
try
{
try
{
auto
k
=
cfg
[
"k"
].
as
<
int
>
();
auto
k
=
cfg
[
"k"
].
as
<
int
>
();
auto
dataset
=
GenDataset
(
nq
,
dim
,
xq
);
auto
dataset
=
GenDataset
(
nq
,
dim
,
xq
);
...
@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon
...
@@ -117,41 +119,47 @@ ErrorCode VecIndexImpl::Search(const long &nq, const float *xq, float *dist, lon
}
catch
(
KnowhereException
&
e
)
{
}
catch
(
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
jsoncons
::
json_exception
&
e
)
{
}
catch
(
jsoncons
::
json_exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_INVALID_ARGUMENT
;
return
Status
(
KNOWHERE_INVALID_ARGUMENT
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
zilliz
::
knowhere
::
BinarySet
VecIndexImpl
::
Serialize
()
{
zilliz
::
knowhere
::
BinarySet
VecIndexImpl
::
Serialize
()
{
type
=
ConvertToCpuIndexType
(
type
);
type
=
ConvertToCpuIndexType
(
type
);
return
index_
->
Serialize
();
return
index_
->
Serialize
();
}
}
ErrorCode
VecIndexImpl
::
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
Status
VecIndexImpl
::
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
index_
->
Load
(
index_binary
);
index_
->
Load
(
index_binary
);
dim
=
Dimension
();
dim
=
Dimension
();
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
int64_t
VecIndexImpl
::
Dimension
()
{
int64_t
VecIndexImpl
::
Dimension
()
{
return
index_
->
Dimension
();
return
index_
->
Dimension
();
}
}
int64_t
VecIndexImpl
::
Count
()
{
int64_t
VecIndexImpl
::
Count
()
{
return
index_
->
Count
();
return
index_
->
Count
();
}
}
IndexType
VecIndexImpl
::
GetType
()
{
IndexType
VecIndexImpl
::
GetType
()
{
return
type
;
return
type
;
}
}
VecIndexPtr
VecIndexImpl
::
CopyToGpu
(
const
int64_t
&
device_id
,
const
Config
&
cfg
)
{
VecIndexPtr
VecIndexImpl
::
CopyToGpu
(
const
int64_t
&
device_id
,
const
Config
&
cfg
)
{
// TODO(linxj): exception handle
// TODO(linxj): exception handle
auto
gpu_index
=
zilliz
::
knowhere
::
cloner
::
CopyCpuToGpu
(
index_
,
device_id
,
cfg
);
auto
gpu_index
=
zilliz
::
knowhere
::
cloner
::
CopyCpuToGpu
(
index_
,
device_id
,
cfg
);
auto
new_index
=
std
::
make_shared
<
VecIndexImpl
>
(
gpu_index
,
ConvertToGpuIndexType
(
type
));
auto
new_index
=
std
::
make_shared
<
VecIndexImpl
>
(
gpu_index
,
ConvertToGpuIndexType
(
type
));
...
@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
...
@@ -159,7 +167,8 @@ VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg)
return
new_index
;
return
new_index
;
}
}
VecIndexPtr
VecIndexImpl
::
CopyToCpu
(
const
Config
&
cfg
)
{
VecIndexPtr
VecIndexImpl
::
CopyToCpu
(
const
Config
&
cfg
)
{
// TODO(linxj): exception handle
// TODO(linxj): exception handle
auto
cpu_index
=
zilliz
::
knowhere
::
cloner
::
CopyGpuToCpu
(
index_
,
cfg
);
auto
cpu_index
=
zilliz
::
knowhere
::
cloner
::
CopyGpuToCpu
(
index_
,
cfg
);
auto
new_index
=
std
::
make_shared
<
VecIndexImpl
>
(
cpu_index
,
ConvertToCpuIndexType
(
type
));
auto
new_index
=
std
::
make_shared
<
VecIndexImpl
>
(
cpu_index
,
ConvertToCpuIndexType
(
type
));
...
@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
...
@@ -167,32 +176,37 @@ VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) {
return
new_index
;
return
new_index
;
}
}
VecIndexPtr
VecIndexImpl
::
Clone
()
{
VecIndexPtr
VecIndexImpl
::
Clone
()
{
// TODO(linxj): exception handle
// TODO(linxj): exception handle
auto
clone_index
=
std
::
make_shared
<
VecIndexImpl
>
(
index_
->
Clone
(),
type
);
auto
clone_index
=
std
::
make_shared
<
VecIndexImpl
>
(
index_
->
Clone
(),
type
);
clone_index
->
dim
=
dim
;
clone_index
->
dim
=
dim
;
return
clone_index
;
return
clone_index
;
}
}
int64_t
VecIndexImpl
::
GetDeviceId
()
{
int64_t
if
(
auto
device_idx
=
std
::
dynamic_pointer_cast
<
GPUIndex
>
(
index_
)){
VecIndexImpl
::
GetDeviceId
()
{
if
(
auto
device_idx
=
std
::
dynamic_pointer_cast
<
GPUIndex
>
(
index_
))
{
return
device_idx
->
GetGpuDevice
();
return
device_idx
->
GetGpuDevice
();
}
}
// else
// else
return
-
1
;
// -1 == cpu
return
-
1
;
// -1 == cpu
}
}
float
*
BFIndex
::
GetRawVectors
()
{
float
*
BFIndex
::
GetRawVectors
()
{
auto
raw_index
=
std
::
dynamic_pointer_cast
<
IDMAP
>
(
index_
);
auto
raw_index
=
std
::
dynamic_pointer_cast
<
IDMAP
>
(
index_
);
if
(
raw_index
)
{
return
raw_index
->
GetRawVectors
();
}
if
(
raw_index
)
{
return
raw_index
->
GetRawVectors
();
}
return
nullptr
;
return
nullptr
;
}
}
int64_t
*
BFIndex
::
GetRawIds
()
{
int64_t
*
BFIndex
::
GetRawIds
()
{
return
std
::
static_pointer_cast
<
IDMAP
>
(
index_
)
->
GetRawIds
();
return
std
::
static_pointer_cast
<
IDMAP
>
(
index_
)
->
GetRawIds
();
}
}
ErrorCode
BFIndex
::
Build
(
const
Config
&
cfg
)
{
ErrorCode
BFIndex
::
Build
(
const
Config
&
cfg
)
{
try
{
try
{
dim
=
cfg
[
"dim"
].
as
<
int
>
();
dim
=
cfg
[
"dim"
].
as
<
int
>
();
std
::
static_pointer_cast
<
IDMAP
>
(
index_
)
->
Train
(
cfg
);
std
::
static_pointer_cast
<
IDMAP
>
(
index_
)
->
Train
(
cfg
);
...
@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) {
...
@@ -209,12 +223,13 @@ ErrorCode BFIndex::Build(const Config &cfg) {
return
KNOWHERE_SUCCESS
;
return
KNOWHERE_SUCCESS
;
}
}
ErrorCode
BFIndex
::
BuildAll
(
const
long
&
nb
,
Status
const
float
*
xb
,
BFIndex
::
BuildAll
(
const
long
&
nb
,
const
long
*
ids
,
const
float
*
xb
,
const
Config
&
cfg
,
const
long
*
ids
,
const
long
&
nt
,
const
Config
&
cfg
,
const
float
*
xt
)
{
const
long
&
nt
,
const
float
*
xt
)
{
try
{
try
{
dim
=
cfg
[
"dim"
].
as
<
int
>
();
dim
=
cfg
[
"dim"
].
as
<
int
>
();
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
...
@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb,
...
@@ -223,24 +238,25 @@ ErrorCode BFIndex::BuildAll(const long &nb,
index_
->
Add
(
dataset
,
cfg
);
index_
->
Add
(
dataset
,
cfg
);
}
catch
(
KnowhereException
&
e
)
{
}
catch
(
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
jsoncons
::
json_exception
&
e
)
{
}
catch
(
jsoncons
::
json_exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_INVALID_ARGUMENT
;
return
Status
(
KNOWHERE_INVALID_ARGUMENT
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
// TODO(linxj): add lock here.
// TODO(linxj): add lock here.
ErrorCode
IVFMixIndex
::
BuildAll
(
const
long
&
nb
,
Status
const
float
*
xb
,
IVFMixIndex
::
BuildAll
(
const
long
&
nb
,
const
long
*
ids
,
const
float
*
xb
,
const
Config
&
cfg
,
const
long
*
ids
,
const
long
&
nt
,
const
Config
&
cfg
,
const
float
*
xt
)
{
const
long
&
nt
,
const
float
*
xt
)
{
try
{
try
{
dim
=
cfg
[
"dim"
].
as
<
int
>
();
dim
=
cfg
[
"dim"
].
as
<
int
>
();
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
auto
dataset
=
GenDatasetWithIds
(
nb
,
dim
,
xb
,
ids
);
...
@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
...
@@ -257,26 +273,27 @@ ErrorCode IVFMixIndex::BuildAll(const long &nb,
type
=
ConvertToCpuIndexType
(
type
);
type
=
ConvertToCpuIndexType
(
type
);
}
else
{
}
else
{
WRAPPER_LOG_ERROR
<<
"Build IVFMIXIndex Failed"
;
WRAPPER_LOG_ERROR
<<
"Build IVFMIXIndex Failed"
;
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
"Build IVFMIXIndex Failed"
)
;
}
}
}
catch
(
KnowhereException
&
e
)
{
}
catch
(
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
jsoncons
::
json_exception
&
e
)
{
}
catch
(
jsoncons
::
json_exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_INVALID_ARGUMENT
;
return
Status
(
KNOWHERE_INVALID_ARGUMENT
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
ErrorCode
IVFMixIndex
::
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
Status
IVFMixIndex
::
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
//index_ = std::make_shared<IVF>();
//index_ = std::make_shared<IVF>();
index_
->
Load
(
index_binary
);
index_
->
Load
(
index_binary
);
dim
=
Dimension
();
dim
=
Dimension
();
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
}
}
...
...
cpp/src/wrapper/vec_impl.h
浏览文件 @
b2003e9f
...
@@ -19,7 +19,6 @@
...
@@ -19,7 +19,6 @@
#pragma once
#pragma once
#include "knowhere/index/vector_index/VectorIndex.h"
#include "knowhere/index/vector_index/VectorIndex.h"
#include "vec_index.h"
#include "vec_index.h"
...
@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
...
@@ -31,27 +30,53 @@ class VecIndexImpl : public VecIndex {
public:
public:
explicit
VecIndexImpl
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
,
const
IndexType
&
type
)
explicit
VecIndexImpl
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
,
const
IndexType
&
type
)
:
index_
(
std
::
move
(
index
)),
type
(
type
)
{};
:
index_
(
std
::
move
(
index
)),
type
(
type
)
{};
ErrorCode
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
Status
const
long
*
ids
,
BuildAll
(
const
long
&
nb
,
const
Config
&
cfg
,
const
float
*
xb
,
const
long
&
nt
,
const
long
*
ids
,
const
float
*
xt
)
override
;
const
Config
&
cfg
,
VecIndexPtr
CopyToGpu
(
const
int64_t
&
device_id
,
const
Config
&
cfg
)
override
;
const
long
&
nt
,
VecIndexPtr
CopyToCpu
(
const
Config
&
cfg
)
override
;
const
float
*
xt
)
override
;
IndexType
GetType
()
override
;
int64_t
Dimension
()
override
;
VecIndexPtr
int64_t
Count
()
override
;
CopyToGpu
(
const
int64_t
&
device_id
,
const
Config
&
cfg
)
override
;
ErrorCode
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
)
override
;
zilliz
::
knowhere
::
BinarySet
Serialize
()
override
;
VecIndexPtr
ErrorCode
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
override
;
CopyToCpu
(
const
Config
&
cfg
)
override
;
VecIndexPtr
Clone
()
override
;
int64_t
GetDeviceId
()
override
;
IndexType
ErrorCode
Search
(
const
long
&
nq
,
const
float
*
xq
,
float
*
dist
,
long
*
ids
,
const
Config
&
cfg
)
override
;
GetType
()
override
;
int64_t
Dimension
()
override
;
int64_t
Count
()
override
;
Status
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
)
override
;
zilliz
::
knowhere
::
BinarySet
Serialize
()
override
;
Status
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
override
;
VecIndexPtr
Clone
()
override
;
int64_t
GetDeviceId
()
override
;
Status
Search
(
const
long
&
nq
,
const
float
*
xq
,
float
*
dist
,
long
*
ids
,
const
Config
&
cfg
)
override
;
protected:
protected:
int64_t
dim
=
0
;
int64_t
dim
=
0
;
IndexType
type
=
IndexType
::
INVALID
;
IndexType
type
=
IndexType
::
INVALID
;
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index_
=
nullptr
;
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index_
=
nullptr
;
};
};
...
@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
...
@@ -60,28 +85,39 @@ class IVFMixIndex : public VecIndexImpl {
explicit
IVFMixIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
,
const
IndexType
&
type
)
explicit
IVFMixIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
,
const
IndexType
&
type
)
:
VecIndexImpl
(
std
::
move
(
index
),
type
)
{};
:
VecIndexImpl
(
std
::
move
(
index
),
type
)
{};
ErrorCode
BuildAll
(
const
long
&
nb
,
Status
const
float
*
xb
,
BuildAll
(
const
long
&
nb
,
const
long
*
ids
,
const
float
*
xb
,
const
Config
&
cfg
,
const
long
*
ids
,
const
long
&
nt
,
const
Config
&
cfg
,
const
float
*
xt
)
override
;
const
long
&
nt
,
ErrorCode
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
override
;
const
float
*
xt
)
override
;
Status
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
override
;
};
};
class
BFIndex
:
public
VecIndexImpl
{
class
BFIndex
:
public
VecIndexImpl
{
public:
public:
explicit
BFIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
VecIndexImpl
(
std
::
move
(
index
),
explicit
BFIndex
(
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
)
:
VecIndexImpl
(
std
::
move
(
index
),
IndexType
::
FAISS_IDMAP
)
{};
IndexType
::
FAISS_IDMAP
)
{};
ErrorCode
Build
(
const
Config
&
cfg
);
float
*
GetRawVectors
();
ErrorCode
ErrorCode
BuildAll
(
const
long
&
nb
,
Build
(
const
Config
&
cfg
);
const
float
*
xb
,
const
long
*
ids
,
float
*
const
Config
&
cfg
,
GetRawVectors
();
const
long
&
nt
,
const
float
*
xt
)
override
;
Status
int64_t
*
GetRawIds
();
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
long
*
ids
,
const
Config
&
cfg
,
const
long
&
nt
,
const
float
*
xt
)
override
;
int64_t
*
GetRawIds
();
};
};
}
}
...
...
cpp/src/wrapper/vec_index.cpp
浏览文件 @
b2003e9f
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Exception.h"
#include "vec_index.h"
#include "vec_index.h"
#include "vec_impl.h"
#include "vec_impl.h"
#include "utils/Log.h"
#include "utils/Log.h"
...
@@ -39,23 +38,19 @@ namespace engine {
...
@@ -39,23 +38,19 @@ namespace engine {
static
constexpr
float
TYPICAL_COUNT
=
1000000.0
;
static
constexpr
float
TYPICAL_COUNT
=
1000000.0
;
struct
FileIOWriter
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOWriter
(
const
std
::
string
&
fname
);
~
FileIOWriter
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
struct
FileIOReader
{
struct
FileIOReader
{
std
::
fstream
fs
;
std
::
fstream
fs
;
std
::
string
name
;
std
::
string
name
;
FileIOReader
(
const
std
::
string
&
fname
);
FileIOReader
(
const
std
::
string
&
fname
);
~
FileIOReader
();
~
FileIOReader
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
size_t
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
);
size_t
operator
()(
void
*
ptr
,
size_t
size
);
size_t
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
);
};
};
FileIOReader
::
FileIOReader
(
const
std
::
string
&
fname
)
{
FileIOReader
::
FileIOReader
(
const
std
::
string
&
fname
)
{
...
@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() {
...
@@ -67,14 +62,27 @@ FileIOReader::~FileIOReader() {
fs
.
close
();
fs
.
close
();
}
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
)
{
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
read
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
fs
.
read
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
}
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
)
{
size_t
FileIOReader
::
operator
()(
void
*
ptr
,
size_t
size
,
size_t
pos
)
{
return
0
;
return
0
;
}
}
struct
FileIOWriter
{
std
::
fstream
fs
;
std
::
string
name
;
FileIOWriter
(
const
std
::
string
&
fname
);
~
FileIOWriter
();
size_t
operator
()(
void
*
ptr
,
size_t
size
);
};
FileIOWriter
::
FileIOWriter
(
const
std
::
string
&
fname
)
{
FileIOWriter
::
FileIOWriter
(
const
std
::
string
&
fname
)
{
name
=
fname
;
name
=
fname
;
fs
=
std
::
fstream
(
name
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
fs
=
std
::
fstream
(
name
,
std
::
ios
::
out
|
std
::
ios
::
binary
);
...
@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() {
...
@@ -84,12 +92,14 @@ FileIOWriter::~FileIOWriter() {
fs
.
close
();
fs
.
close
();
}
}
size_t
FileIOWriter
::
operator
()(
void
*
ptr
,
size_t
size
)
{
size_t
FileIOWriter
::
operator
()(
void
*
ptr
,
size_t
size
)
{
fs
.
write
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
fs
.
write
(
reinterpret_cast
<
char
*>
(
ptr
),
size
);
}
}
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
)
{
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
)
{
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
;
std
::
shared_ptr
<
zilliz
::
knowhere
::
VectorIndex
>
index
;
auto
gpu_device
=
cfg
.
get_with_default
(
"gpu_id"
,
0
);
auto
gpu_device
=
cfg
.
get_with_default
(
"gpu_id"
,
0
);
switch
(
type
)
{
switch
(
type
)
{
...
@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
...
@@ -145,13 +155,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
return
std
::
make_shared
<
VecIndexImpl
>
(
index
,
type
);
return
std
::
make_shared
<
VecIndexImpl
>
(
index
,
type
);
}
}
VecIndexPtr
LoadVecIndex
(
const
IndexType
&
index_type
,
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
VecIndexPtr
LoadVecIndex
(
const
IndexType
&
index_type
,
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
auto
index
=
GetVecIndexFactory
(
index_type
);
auto
index
=
GetVecIndexFactory
(
index_type
);
index
->
Load
(
index_binary
);
index
->
Load
(
index_binary
);
return
index
;
return
index
;
}
}
VecIndexPtr
read_index
(
const
std
::
string
&
location
)
{
VecIndexPtr
read_index
(
const
std
::
string
&
location
)
{
knowhere
::
BinarySet
load_data_list
;
knowhere
::
BinarySet
load_data_list
;
FileIOReader
reader
(
location
);
FileIOReader
reader
(
location
);
reader
.
fs
.
seekg
(
0
,
reader
.
fs
.
end
);
reader
.
fs
.
seekg
(
0
,
reader
.
fs
.
end
);
...
@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) {
...
@@ -195,7 +207,8 @@ VecIndexPtr read_index(const std::string &location) {
return
LoadVecIndex
(
current_type
,
load_data_list
);
return
LoadVecIndex
(
current_type
,
load_data_list
);
}
}
ErrorCode
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
)
{
Status
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
)
{
try
{
try
{
auto
binaryset
=
index
->
Serialize
();
auto
binaryset
=
index
->
Serialize
();
auto
index_type
=
index
->
GetType
();
auto
index_type
=
index
->
GetType
();
...
@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) {
...
@@ -215,28 +228,29 @@ ErrorCode write_index(VecIndexPtr index, const std::string &location) {
}
}
}
catch
(
knowhere
::
KnowhereException
&
e
)
{
}
catch
(
knowhere
::
KnowhereException
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
return
KNOWHERE_UNEXPECTED_ERROR
;
return
Status
(
KNOWHERE_UNEXPECTED_ERROR
,
e
.
what
())
;
}
catch
(
std
::
exception
&
e
)
{
}
catch
(
std
::
exception
&
e
)
{
WRAPPER_LOG_ERROR
<<
e
.
what
();
WRAPPER_LOG_ERROR
<<
e
.
what
();
std
::
string
estring
(
e
.
what
());
std
::
string
estring
(
e
.
what
());
if
(
estring
.
find
(
"No space left on device"
)
!=
estring
.
npos
)
{
if
(
estring
.
find
(
"No space left on device"
)
!=
estring
.
npos
)
{
WRAPPER_LOG_ERROR
<<
"No space left on the device"
;
WRAPPER_LOG_ERROR
<<
"No space left on the device"
;
return
KNOWHERE_NO_SPACE
;
return
Status
(
KNOWHERE_NO_SPACE
,
"No space left on the device"
)
;
}
else
{
}
else
{
return
KNOWHERE_ERROR
;
return
Status
(
KNOWHERE_ERROR
,
e
.
what
())
;
}
}
}
}
return
KNOWHERE_SUCCESS
;
return
Status
::
OK
()
;
}
}
// TODO(linxj): redo here.
// TODO(linxj): redo here.
void
AutoGenParams
(
const
IndexType
&
type
,
const
long
&
size
,
zilliz
::
knowhere
::
Config
&
cfg
)
{
void
AutoGenParams
(
const
IndexType
&
type
,
const
long
&
size
,
zilliz
::
knowhere
::
Config
&
cfg
)
{
auto
nlist
=
cfg
.
get_with_default
(
"nlist"
,
0
);
auto
nlist
=
cfg
.
get_with_default
(
"nlist"
,
0
);
if
(
size
<=
TYPICAL_COUNT
/
16384
+
1
)
{
if
(
size
<=
TYPICAL_COUNT
/
16384
+
1
)
{
//handle less row count, avoid nlist set to 0
//handle less row count, avoid nlist set to 0
cfg
[
"nlist"
]
=
1
;
cfg
[
"nlist"
]
=
1
;
}
else
if
(
int
(
size
/
TYPICAL_COUNT
)
*
nlist
==
0
)
{
}
else
if
(
int
(
size
/
TYPICAL_COUNT
)
*
nlist
==
0
)
{
//calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
//calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
cfg
[
"nlist"
]
=
int
(
size
/
TYPICAL_COUNT
*
16384
);
cfg
[
"nlist"
]
=
int
(
size
/
TYPICAL_COUNT
*
16384
);
}
}
...
@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
...
@@ -270,7 +284,8 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co
#define GPU_MAX_NRPOBE 1024
#define GPU_MAX_NRPOBE 1024
#endif
#endif
void
ParameterValidation
(
const
IndexType
&
type
,
Config
&
cfg
)
{
void
ParameterValidation
(
const
IndexType
&
type
,
Config
&
cfg
)
{
switch
(
type
)
{
switch
(
type
)
{
case
IndexType
::
FAISS_IVFSQ8_GPU
:
case
IndexType
::
FAISS_IVFSQ8_GPU
:
case
IndexType
::
FAISS_IVFFLAT_GPU
:
case
IndexType
::
FAISS_IVFFLAT_GPU
:
...
@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) {
...
@@ -291,7 +306,8 @@ void ParameterValidation(const IndexType &type, Config &cfg) {
}
}
}
}
IndexType
ConvertToCpuIndexType
(
const
IndexType
&
type
)
{
IndexType
ConvertToCpuIndexType
(
const
IndexType
&
type
)
{
// TODO(linxj): add IDMAP
// TODO(linxj): add IDMAP
switch
(
type
)
{
switch
(
type
)
{
case
IndexType
::
FAISS_IVFFLAT_GPU
:
case
IndexType
::
FAISS_IVFFLAT_GPU
:
...
@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) {
...
@@ -308,7 +324,8 @@ IndexType ConvertToCpuIndexType(const IndexType &type) {
}
}
}
}
IndexType
ConvertToGpuIndexType
(
const
IndexType
&
type
)
{
IndexType
ConvertToGpuIndexType
(
const
IndexType
&
type
)
{
switch
(
type
)
{
switch
(
type
)
{
case
IndexType
::
FAISS_IVFFLAT_MIX
:
case
IndexType
::
FAISS_IVFFLAT_MIX
:
case
IndexType
::
FAISS_IVFFLAT_CPU
:
{
case
IndexType
::
FAISS_IVFFLAT_CPU
:
{
...
...
cpp/src/wrapper/vec_index.h
浏览文件 @
b2003e9f
...
@@ -21,8 +21,7 @@
...
@@ -21,8 +21,7 @@
#include <string>
#include <string>
#include <memory>
#include <memory>
#include "utils/Error.h"
#include "utils/Status.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/Config.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/BinarySet.h"
...
@@ -50,62 +49,84 @@ enum class IndexType {
...
@@ -50,62 +49,84 @@ enum class IndexType {
};
};
class
VecIndex
;
class
VecIndex
;
using
VecIndexPtr
=
std
::
shared_ptr
<
VecIndex
>
;
using
VecIndexPtr
=
std
::
shared_ptr
<
VecIndex
>
;
class
VecIndex
{
class
VecIndex
{
public:
public:
virtual
ErrorCode
BuildAll
(
const
long
&
nb
,
virtual
Status
const
float
*
xb
,
BuildAll
(
const
long
&
nb
,
const
long
*
ids
,
const
float
*
xb
,
const
Config
&
cfg
,
const
long
*
ids
,
const
long
&
nt
=
0
,
const
Config
&
cfg
,
const
float
*
xt
=
nullptr
)
=
0
;
const
long
&
nt
=
0
,
const
float
*
xt
=
nullptr
)
=
0
;
virtual
ErrorCode
Add
(
const
long
&
nb
,
const
float
*
xb
,
virtual
Status
const
long
*
ids
,
Add
(
const
long
&
nb
,
const
Config
&
cfg
=
Config
())
=
0
;
const
float
*
xb
,
const
long
*
ids
,
virtual
ErrorCode
Search
(
const
long
&
nq
,
const
Config
&
cfg
=
Config
())
=
0
;
const
float
*
xq
,
float
*
dist
,
virtual
Status
long
*
ids
,
Search
(
const
long
&
nq
,
const
Config
&
cfg
=
Config
())
=
0
;
const
float
*
xq
,
float
*
dist
,
virtual
VecIndexPtr
CopyToGpu
(
const
int64_t
&
device_id
,
long
*
ids
,
const
Config
&
cfg
=
Config
())
=
0
;
const
Config
&
cfg
=
Config
())
=
0
;
virtual
VecIndexPtr
CopyToCpu
(
const
Config
&
cfg
=
Config
())
=
0
;
virtual
VecIndexPtr
CopyToGpu
(
const
int64_t
&
device_id
,
virtual
VecIndexPtr
Clone
()
=
0
;
const
Config
&
cfg
=
Config
())
=
0
;
virtual
int64_t
GetDeviceId
()
=
0
;
virtual
VecIndexPtr
CopyToCpu
(
const
Config
&
cfg
=
Config
())
=
0
;
virtual
IndexType
GetType
()
=
0
;
virtual
VecIndexPtr
virtual
int64_t
Dimension
()
=
0
;
Clone
()
=
0
;
virtual
int64_t
Count
()
=
0
;
virtual
int64_t
GetDeviceId
()
=
0
;
virtual
zilliz
::
knowhere
::
BinarySet
Serialize
()
=
0
;
virtual
IndexType
virtual
ErrorCode
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
=
0
;
GetType
()
=
0
;
virtual
int64_t
Dimension
()
=
0
;
virtual
int64_t
Count
()
=
0
;
virtual
zilliz
::
knowhere
::
BinarySet
Serialize
()
=
0
;
virtual
Status
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
=
0
;
};
};
extern
ErrorCode
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
);
extern
Status
write_index
(
VecIndexPtr
index
,
const
std
::
string
&
location
);
extern
VecIndexPtr
read_index
(
const
std
::
string
&
location
);
extern
VecIndexPtr
read_index
(
const
std
::
string
&
location
);
extern
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
extern
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
extern
VecIndexPtr
LoadVecIndex
(
const
IndexType
&
index_type
,
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
);
extern
VecIndexPtr
LoadVecIndex
(
const
IndexType
&
index_type
,
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
);
extern
void
AutoGenParams
(
const
IndexType
&
type
,
const
long
&
size
,
Config
&
cfg
);
extern
void
AutoGenParams
(
const
IndexType
&
type
,
const
long
&
size
,
Config
&
cfg
);
extern
void
ParameterValidation
(
const
IndexType
&
type
,
Config
&
cfg
);
extern
void
ParameterValidation
(
const
IndexType
&
type
,
Config
&
cfg
);
extern
IndexType
ConvertToCpuIndexType
(
const
IndexType
&
type
);
extern
IndexType
ConvertToCpuIndexType
(
const
IndexType
&
type
);
extern
IndexType
extern
IndexType
ConvertToGpuIndexType
(
const
IndexType
&
type
);
ConvertToGpuIndexType
(
const
IndexType
&
type
);
}
}
}
}
...
...
cpp/unittest/scheduler/scheduler_test.cpp
浏览文件 @
b2003e9f
...
@@ -17,8 +17,7 @@
...
@@ -17,8 +17,7 @@
#include "scheduler/Scheduler.h"
#include "scheduler/Scheduler.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <src/scheduler/tasklabel/DefaultLabel.h>
#include "src/scheduler/tasklabel/DefaultLabel.h"
#include <src/server/ServerConfig.h>
#include "cache/DataObj.h"
#include "cache/DataObj.h"
#include "cache/GpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "scheduler/task/TestTask.h"
#include "scheduler/task/TestTask.h"
...
@@ -35,13 +34,12 @@ namespace engine {
...
@@ -35,13 +34,12 @@ namespace engine {
class
MockVecIndex
:
public
engine
::
VecIndex
{
class
MockVecIndex
:
public
engine
::
VecIndex
{
public:
public:
virtual
ErrorCode
BuildAll
(
const
long
&
nb
,
virtual
Status
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
float
*
xb
,
const
long
*
ids
,
const
long
*
ids
,
const
engine
::
Config
&
cfg
,
const
engine
::
Config
&
cfg
,
const
long
&
nt
=
0
,
const
long
&
nt
=
0
,
const
float
*
xt
=
nullptr
)
{
const
float
*
xt
=
nullptr
)
{
}
}
engine
::
VecIndexPtr
Clone
()
override
{
engine
::
VecIndexPtr
Clone
()
override
{
...
@@ -56,14 +54,14 @@ public:
...
@@ -56,14 +54,14 @@ public:
return
engine
::
IndexType
::
INVALID
;
return
engine
::
IndexType
::
INVALID
;
}
}
virtual
ErrorCode
Add
(
const
long
&
nb
,
virtual
Status
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
float
*
xb
,
const
long
*
ids
,
const
long
*
ids
,
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
}
}
virtual
ErrorCode
Search
(
const
long
&
nq
,
virtual
Status
Search
(
const
long
&
nq
,
const
float
*
xq
,
const
float
*
xq
,
float
*
dist
,
float
*
dist
,
long
*
ids
,
long
*
ids
,
...
@@ -92,7 +90,7 @@ public:
...
@@ -92,7 +90,7 @@ public:
return
binset
;
return
binset
;
}
}
virtual
ErrorCode
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
virtual
Status
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
}
}
...
...
cpp/unittest/server/cache_test.cpp
浏览文件 @
b2003e9f
...
@@ -19,7 +19,6 @@
...
@@ -19,7 +19,6 @@
#include "cache/CpuCacheMgr.h"
#include "cache/CpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "cache/GpuCacheMgr.h"
#include "server/ServerConfig.h"
#include "server/ServerConfig.h"
#include "utils/Error.h"
#include "utils/Error.h"
#include "src/wrapper/vec_index.h"
#include "src/wrapper/vec_index.h"
...
@@ -48,13 +47,13 @@ public:
...
@@ -48,13 +47,13 @@ public:
}
}
virtual
ErrorCode
BuildAll
(
const
long
&
nb
,
virtual
Status
BuildAll
(
const
long
&
nb
,
const
float
*
xb
,
const
float
*
xb
,
const
long
*
ids
,
const
long
*
ids
,
const
engine
::
Config
&
cfg
,
const
engine
::
Config
&
cfg
,
const
long
&
nt
=
0
,
const
long
&
nt
=
0
,
const
float
*
xt
=
nullptr
)
{
const
float
*
xt
=
nullptr
)
{
return
0
;
return
Status
()
;
}
}
engine
::
VecIndexPtr
Clone
()
override
{
engine
::
VecIndexPtr
Clone
()
override
{
...
@@ -69,19 +68,19 @@ public:
...
@@ -69,19 +68,19 @@ public:
return
engine
::
IndexType
::
INVALID
;
return
engine
::
IndexType
::
INVALID
;
}
}
virtual
ErrorCode
Add
(
const
long
&
nb
,
virtual
Status
Add
(
const
long
&
nb
,
const
float
*
xb
,
const
float
*
xb
,
const
long
*
ids
,
const
long
*
ids
,
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
return
0
;
return
Status
()
;
}
}
virtual
ErrorCode
Search
(
const
long
&
nq
,
virtual
Status
Search
(
const
long
&
nq
,
const
float
*
xq
,
const
float
*
xq
,
float
*
dist
,
float
*
dist
,
long
*
ids
,
long
*
ids
,
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
const
engine
::
Config
&
cfg
=
engine
::
Config
())
{
return
0
;
return
Status
()
;
}
}
engine
::
VecIndexPtr
CopyToGpu
(
const
int64_t
&
device_id
,
engine
::
VecIndexPtr
CopyToGpu
(
const
int64_t
&
device_id
,
...
@@ -106,8 +105,8 @@ public:
...
@@ -106,8 +105,8 @@ public:
return
binset
;
return
binset
;
}
}
virtual
ErrorCode
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
virtual
Status
Load
(
const
zilliz
::
knowhere
::
BinarySet
&
index_binary
)
{
return
0
;
return
Status
()
;
}
}
public:
public:
...
...
cpp/unittest/wrapper/CMakeLists.txt
浏览文件 @
b2003e9f
...
@@ -28,6 +28,7 @@ set(wrapper_files
...
@@ -28,6 +28,7 @@ set(wrapper_files
set
(
util_files
set
(
util_files
utils.cpp
utils.cpp
${
MILVUS_ENGINE_SRC
}
/utils/easylogging++.cc
${
MILVUS_ENGINE_SRC
}
/utils/easylogging++.cc
${
MILVUS_ENGINE_SRC
}
/utils/Status.cpp
)
)
set
(
knowhere_libs
set
(
knowhere_libs
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录