Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
milvus
milvus
提交
bfd4c6c5
M
milvus
项目概览
milvus
/
milvus
11 个月 前同步成功
通知
261
Star
22476
Fork
2472
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
bfd4c6c5
编写于
11月 20, 2019
作者:
X
xiaojun.lin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
upgrade SPTAG and support KDT and BKT
上级
0aa90ea9
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
883 addition
and
292 deletion
+883
-292
core/src/db/engine/ExecutionEngine.h
core/src/db/engine/ExecutionEngine.h
+3
-1
core/src/db/engine/ExecutionEngineImpl.cpp
core/src/db/engine/ExecutionEngineImpl.cpp
+8
-0
core/src/index/knowhere/CMakeLists.txt
core/src/index/knowhere/CMakeLists.txt
+2
-2
core/src/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
...c/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
+0
-180
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
...index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
+348
-0
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h
...c/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h
+93
-0
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
...here/knowhere/index/vector_index/helpers/IndexParameter.h
+77
-2
core/src/index/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp
...e/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp
+0
-55
core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
...knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
+75
-0
core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h
...e/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h
+28
-19
core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
...dex/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
+1
-1
core/src/index/unittest/CMakeLists.txt
core/src/index/unittest/CMakeLists.txt
+8
-8
core/src/index/unittest/test_sptag.cpp
core/src/index/unittest/test_sptag.cpp
+154
-0
core/src/index/unittest/utils.cpp
core/src/index/unittest/utils.cpp
+8
-6
core/src/wrapper/ConfAdapter.cpp
core/src/wrapper/ConfAdapter.cpp
+30
-0
core/src/wrapper/ConfAdapter.h
core/src/wrapper/ConfAdapter.h
+18
-0
core/src/wrapper/ConfAdapterMgr.cpp
core/src/wrapper/ConfAdapterMgr.cpp
+3
-0
core/src/wrapper/VecIndex.cpp
core/src/wrapper/VecIndex.cpp
+6
-2
core/src/wrapper/VecIndex.h
core/src/wrapper/VecIndex.h
+4
-0
core/unittest/wrapper/test_wrapper.cpp
core/unittest/wrapper/test_wrapper.cpp
+17
-16
未找到文件。
core/src/db/engine/ExecutionEngine.h
浏览文件 @
bfd4c6c5
...
@@ -35,7 +35,9 @@ enum class EngineType {
...
@@ -35,7 +35,9 @@ enum class EngineType {
NSG_MIX
,
NSG_MIX
,
FAISS_IVFSQ8H
,
FAISS_IVFSQ8H
,
FAISS_PQ
,
FAISS_PQ
,
MAX_VALUE
=
FAISS_PQ
,
SPTAG_KDT
,
SPTAG_BKT
,
MAX_VALUE
=
SPTAG_BKT
,
};
};
enum
class
MetricType
{
enum
class
MetricType
{
...
...
core/src/db/engine/ExecutionEngineImpl.cpp
浏览文件 @
bfd4c6c5
...
@@ -124,6 +124,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
...
@@ -124,6 +124,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
#endif
#endif
break
;
break
;
}
}
case
EngineType
::
SPTAG_KDT
:
{
index
=
GetVecIndexFactory
(
IndexType
::
SPTAG_KDT_RNT_CPU
);
break
;
}
case
EngineType
::
SPTAG_BKT
:
{
index
=
GetVecIndexFactory
(
IndexType
::
SPTAG_BKT_RNT_CPU
);
break
;
}
default:
{
default:
{
ENGINE_LOG_ERROR
<<
"Unsupported index type"
;
ENGINE_LOG_ERROR
<<
"Unsupported index type"
;
return
nullptr
;
return
nullptr
;
...
...
core/src/index/knowhere/CMakeLists.txt
浏览文件 @
bfd4c6c5
...
@@ -30,10 +30,10 @@ set(external_srcs
...
@@ -30,10 +30,10 @@ set(external_srcs
set
(
index_srcs
set
(
index_srcs
knowhere/index/preprocessor/Normalize.cpp
knowhere/index/preprocessor/Normalize.cpp
knowhere/index/vector_index/Index
KDT
.cpp
knowhere/index/vector_index/Index
SPTAG
.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/helpers/
KDT
ParameterMgr.cpp
knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
...
...
core/src/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
已删除
100644 → 0
浏览文件 @
0aa90ea9
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <sstream>
#include <vector>
#undef mkdir
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
//#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
namespace
knowhere
{
BinarySet
CPUKDTRNG
::
Serialize
()
{
std
::
vector
<
void
*>
index_blobs
;
std
::
vector
<
int64_t
>
index_len
;
// TODO(zirui): dev
// index_ptr_->SaveIndexToMemory(index_blobs, index_len);
BinarySet
binary_set
;
//
// auto sample = std::make_shared<uint8_t>();
// sample.reset(static_cast<uint8_t*>(index_blobs[0]));
// auto tree = std::make_shared<uint8_t>();
// tree.reset(static_cast<uint8_t*>(index_blobs[1]));
// auto graph = std::make_shared<uint8_t>();
// graph.reset(static_cast<uint8_t*>(index_blobs[2]));
// auto metadata = std::make_shared<uint8_t>();
// metadata.reset(static_cast<uint8_t*>(index_blobs[3]));
//
// binary_set.Append("samples", sample, index_len[0]);
// binary_set.Append("tree", tree, index_len[1]);
// binary_set.Append("graph", graph, index_len[2]);
// binary_set.Append("metadata", metadata, index_len[3]);
return
binary_set
;
}
void
CPUKDTRNG
::
Load
(
const
BinarySet
&
binary_set
)
{
// TODO(zirui): dev
// std::vector<void*> index_blobs;
//
// auto samples = binary_set.GetByName("samples");
// index_blobs.push_back(samples->data.get());
//
// auto tree = binary_set.GetByName("tree");
// index_blobs.push_back(tree->data.get());
//
// auto graph = binary_set.GetByName("graph");
// index_blobs.push_back(graph->data.get());
//
// auto metadata = binary_set.GetByName("metadata");
// index_blobs.push_back(metadata->data.get());
//
// index_ptr_->LoadIndexFromMemory(index_blobs);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
IndexModelPtr
CPUKDTRNG
::
Train
(
const
DatasetPtr
&
origin
,
const
Config
&
train_config
)
{
SetParameters
(
train_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
BuildIndex
(
vectorset
,
metaset
);
// TODO: return IndexModelPtr
return
nullptr
;
}
void
CPUKDTRNG
::
Add
(
const
DatasetPtr
&
origin
,
const
Config
&
add_config
)
{
SetParameters
(
add_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
AddIndex
(
vectorset
,
metaset
);
}
void
CPUKDTRNG
::
SetParameters
(
const
Config
&
config
)
{
for
(
auto
&
para
:
KDTParameterMgr
::
GetInstance
().
GetKDTParameters
())
{
// auto value = config.get_with_default(para.first, para.second);
index_ptr_
->
SetParameter
(
para
.
first
,
para
.
second
);
}
}
DatasetPtr
CPUKDTRNG
::
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
SetParameters
(
config
);
auto
tensor
=
dataset
->
tensor
()[
0
];
auto
p
=
(
float
*
)
tensor
->
raw_mutable_data
();
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
for
(
auto
j
=
0
;
j
<
10
;
++
j
)
{
std
::
cout
<<
p
[
i
*
10
+
j
]
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
vector
<
SPTAG
::
QueryResult
>
query_results
=
ConvertToQueryResult
(
dataset
,
config
);
#pragma omp parallel for
for
(
auto
i
=
0
;
i
<
query_results
.
size
();
++
i
)
{
auto
target
=
(
float
*
)
query_results
[
i
].
GetTarget
();
std
::
cout
<<
target
[
0
]
<<
", "
<<
target
[
1
]
<<
", "
<<
target
[
2
]
<<
std
::
endl
;
index_ptr_
->
SearchIndex
(
query_results
[
i
]);
}
return
ConvertToDataset
(
query_results
);
}
int64_t
CPUKDTRNG
::
Count
()
{
index_ptr_
->
GetNumSamples
();
}
int64_t
CPUKDTRNG
::
Dimension
()
{
index_ptr_
->
GetFeatureDim
();
}
VectorIndexPtr
CPUKDTRNG
::
Clone
()
{
KNOWHERE_THROW_MSG
(
"not support"
);
}
void
CPUKDTRNG
::
Seal
()
{
// do nothing
}
// TODO(linxj):
BinarySet
CPUKDTRNGIndexModel
::
Serialize
()
{
}
void
CPUKDTRNGIndexModel
::
Load
(
const
BinarySet
&
binary
)
{
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
0 → 100644
浏览文件 @
bfd4c6c5
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <sstream>
#include <vector>
#include <array>
#undef mkdir
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
namespace
knowhere
{
CPUSPTAGRNG
::
CPUSPTAGRNG
(
const
std
::
string
&
IndexType
)
{
if
(
IndexType
==
"KDT"
)
{
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
KDT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
index_type_
=
SPTAG
::
IndexAlgoType
::
KDT
;
}
else
{
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
BKT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
index_type_
=
SPTAG
::
IndexAlgoType
::
BKT
;
}
}
BinarySet
CPUSPTAGRNG
::
Serialize
()
{
std
::
string
index_config
;
std
::
vector
<
SPTAG
::
ByteArray
>
index_blobs
;
std
::
shared_ptr
<
std
::
vector
<
std
::
uint64_t
>>
buffersize
=
index_ptr_
->
CalculateBufferSize
();
std
::
vector
<
char
*>
res
(
buffersize
->
size
()
+
1
);
for
(
uint64_t
i
=
1
;
i
<
res
.
size
();
i
++
)
{
res
[
i
]
=
new
char
[
buffersize
->
at
(
i
-
1
)];
auto
ptr
=
&
res
[
i
][
0
];
index_blobs
.
emplace_back
(
SPTAG
::
ByteArray
((
std
::
uint8_t
*
)
ptr
,
buffersize
->
at
(
i
-
1
),
false
));
}
index_ptr_
->
SaveIndex
(
index_config
,
index_blobs
);
size_t
length
=
index_config
.
length
();
char
*
cstr
=
new
char
[
length
];
snprintf
(
cstr
,
length
,
"%s"
,
index_config
.
c_str
());
BinarySet
binary_set
;
auto
sample
=
std
::
make_shared
<
uint8_t
>
();
sample
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
0
].
Data
()));
auto
tree
=
std
::
make_shared
<
uint8_t
>
();
tree
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
1
].
Data
()));
auto
graph
=
std
::
make_shared
<
uint8_t
>
();
graph
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
2
].
Data
()));
auto
deleteid
=
std
::
make_shared
<
uint8_t
>
();
deleteid
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
3
].
Data
()));
auto
metadata1
=
std
::
make_shared
<
uint8_t
>
();
metadata1
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
4
].
Data
()));
auto
metadata2
=
std
::
make_shared
<
uint8_t
>
();
metadata2
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
5
].
Data
()));
auto
config
=
std
::
make_shared
<
uint8_t
>
();
config
.
reset
(
static_cast
<
uint8_t
*>
((
void
*
)
cstr
));
binary_set
.
Append
(
"samples"
,
sample
,
index_blobs
[
0
].
Length
());
binary_set
.
Append
(
"tree"
,
tree
,
index_blobs
[
1
].
Length
());
binary_set
.
Append
(
"deleteid"
,
deleteid
,
index_blobs
[
3
].
Length
());
binary_set
.
Append
(
"metadata1"
,
metadata1
,
index_blobs
[
4
].
Length
());
binary_set
.
Append
(
"metadata2"
,
metadata2
,
index_blobs
[
5
].
Length
());
binary_set
.
Append
(
"config"
,
config
,
length
);
binary_set
.
Append
(
"graph"
,
graph
,
index_blobs
[
2
].
Length
());
// MemoryIOWriter writer;
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// len = index_blobs[i].Length();
// assert(len != 0);
// writer(&len, sizeof(size_t), 1);
// writer(index_blobs[i].Data(), len, 1);
// len = 0;
// }
// writer(&length, sizeof(size_t), 1);
// writer(cstr, length, 1);
// auto data = std::make_shared<uint8_t>();
// data.reset(writer.data_);
// BinarySet binary_set;
// binary_set.Append("sptag", data, writer.total);
// MemoryIOWriter writer;
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// if (i == 2) continue;
// len = index_blobs[i].Length();
// assert(len != 0);
// writer(&len, sizeof(size_t), 1);
// writer(index_blobs[i].Data(), len, 1);
// len = 0;
// }
// writer(&length, sizeof(size_t), 1);
// writer(cstr, length, 1);
// auto data = std::make_shared<uint8_t>();
// data.reset(writer.data_);
// BinarySet binary_set;
// binary_set.Append("sptag", data, writer.total);
// auto graph = std::make_shared<uint8_t>();
// graph.reset(static_cast<uint8_t*>(index_blobs[2].Data()));
// binary_set.Append("graph", graph, index_blobs[2].Length());
return
binary_set
;
}
void
CPUSPTAGRNG
::
Load
(
const
BinarySet
&
binary_set
)
{
std
::
string
index_config
;
std
::
vector
<
SPTAG
::
ByteArray
>
index_blobs
;
auto
samples
=
binary_set
.
GetByName
(
"samples"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
samples
->
data
.
get
(),
samples
->
size
,
false
));
auto
tree
=
binary_set
.
GetByName
(
"tree"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
tree
->
data
.
get
(),
tree
->
size
,
false
));
auto
graph
=
binary_set
.
GetByName
(
"graph"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
graph
->
data
.
get
(),
graph
->
size
,
false
));
auto
deleteid
=
binary_set
.
GetByName
(
"deleteid"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
deleteid
->
data
.
get
(),
deleteid
->
size
,
false
));
auto
metadata1
=
binary_set
.
GetByName
(
"metadata1"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
metadata1
->
data
.
get
(),
metadata1
->
size
,
false
));
auto
metadata2
=
binary_set
.
GetByName
(
"metadata2"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
metadata2
->
data
.
get
(),
metadata2
->
size
,
false
));
auto
config
=
binary_set
.
GetByName
(
"config"
);
index_config
=
reinterpret_cast
<
char
*>
(
config
->
data
.
get
());
// std::vector<SPTAG::ByteArray> index_blobs;
// auto data = binary_set.GetByName("sptag");
// MemoryIOReader reader;
// reader.total = data->size;
// reader.data_ = data->data.get();
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto binary = new uint8_t[len];
// reader(binary, len, 1);
// index_blobs.emplace_back(SPTAG::ByteArray(binary, len, true));
// len = 0;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto config = new char[len];
// reader(config, len, 1);
// std::string index_config = config;
// delete[] config;
// std::vector<SPTAG::ByteArray> index_blobs;
// auto data = binary_set.GetByName("sptag");
// MemoryIOReader reader;
// reader.total = data->size;
// reader.data_ = data->data.get();
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// if (i == 2) {
// auto graph = binary_set.GetByName("graph");
// index_blobs.emplace_back(SPTAG::ByteArray(graph->data.get(), graph->size, false));
// continue;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto binary = new uint8_t[len];
// reader(binary, len, 1);
// index_blobs.emplace_back(SPTAG::ByteArray(binary, len, true));
// len = 0;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto config = new char[len];
// reader(config, len, 1);
// std::string index_config = config;
// delete[] config;
index_ptr_
->
LoadIndex
(
index_config
,
index_blobs
);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
IndexModelPtr
CPUSPTAGRNG
::
Train
(
const
DatasetPtr
&
origin
,
const
Config
&
train_config
)
{
SetParameters
(
train_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
BuildIndex
(
vectorset
,
metaset
);
// TODO: return IndexModelPtr
return
nullptr
;
}
void
CPUSPTAGRNG
::
Add
(
const
DatasetPtr
&
origin
,
const
Config
&
add_config
)
{
// SetParameters(add_config);
// DatasetPtr dataset = origin->Clone();
//
// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// // && preprocessor_) {
// // preprocessor_->Preprocess(dataset);
// //}
//
// auto vectorset = ConvertToVectorSet(dataset);
// auto metaset = ConvertToMetadataSet(dataset);
// index_ptr_->AddIndex(vectorset, metaset);
}
void
CPUSPTAGRNG
::
SetParameters
(
const
Config
&
config
)
{
#define Assign(param_name, str_name) \
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
if
(
index_type_
==
SPTAG
::
IndexAlgoType
::
KDT
)
{
auto
conf
=
std
::
dynamic_pointer_cast
<
KDTCfg
>
(
config
);
auto
build_cfg
=
SPTAGParameterMgr
::
GetInstance
().
GetKDTParameters
();
Assign
(
kdtnumber
,
"KDTNumber"
);
Assign
(
numtopdimensionkdtsplit
,
"NumTopDimensionKDTSplit"
);
Assign
(
samples
,
"Samples"
);
Assign
(
tptnumber
,
"TPTNumber"
);
Assign
(
tptleafsize
,
"TPTLeafSize"
);
Assign
(
numtopdimensiontptsplit
,
"NumTopDimensionTPTSplit"
);
Assign
(
neighborhoodsize
,
"NeighborhoodSize"
);
Assign
(
graphneighborhoodscale
,
"GraphNeighborhoodScale"
);
Assign
(
graphcefscale
,
"GraphCEFScale"
);
Assign
(
refineiterations
,
"RefineIterations"
);
Assign
(
cef
,
"CEF"
);
Assign
(
maxcheckforrefinegraph
,
"MaxCheckForRefineGraph"
);
Assign
(
numofthreads
,
"NumberOfThreads"
);
Assign
(
maxcheck
,
"MaxCheck"
);
Assign
(
thresholdofnumberofcontinuousnobetterpropagation
,
"ThresholdOfNumberOfContinuousNoBetterPropagation"
);
Assign
(
numberofinitialdynamicpivots
,
"NumberOfInitialDynamicPivots"
);
Assign
(
numberofotherdynamicpivots
,
"NumberOfOtherDynamicPivots"
);
}
else
{
auto
conf
=
std
::
dynamic_pointer_cast
<
BKTCfg
>
(
config
);
auto
build_cfg
=
SPTAGParameterMgr
::
GetInstance
().
GetBKTParameters
();
Assign
(
bktnumber
,
"BKTNumber"
);
Assign
(
bktkmeansk
,
"BKTKMeansK"
);
Assign
(
bktleafsize
,
"BKTLeafSize"
);
Assign
(
samples
,
"Samples"
);
Assign
(
tptnumber
,
"TPTNumber"
);
Assign
(
tptleafsize
,
"TPTLeafSize"
);
Assign
(
numtopdimensiontptsplit
,
"NumTopDimensionTPTSplit"
);
Assign
(
neighborhoodsize
,
"NeighborhoodSize"
);
Assign
(
graphneighborhoodscale
,
"GraphNeighborhoodScale"
);
Assign
(
graphcefscale
,
"GraphCEFScale"
);
Assign
(
refineiterations
,
"RefineIterations"
);
Assign
(
cef
,
"CEF"
);
Assign
(
maxcheckforrefinegraph
,
"MaxCheckForRefineGraph"
);
Assign
(
numofthreads
,
"NumberOfThreads"
);
Assign
(
maxcheck
,
"MaxCheck"
);
Assign
(
thresholdofnumberofcontinuousnobetterpropagation
,
"ThresholdOfNumberOfContinuousNoBetterPropagation"
);
Assign
(
numberofinitialdynamicpivots
,
"NumberOfInitialDynamicPivots"
);
Assign
(
numberofotherdynamicpivots
,
"NumberOfOtherDynamicPivots"
);
}
}
DatasetPtr
CPUSPTAGRNG
::
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
SetParameters
(
config
);
auto
tensor
=
dataset
->
tensor
()[
0
];
auto
p
=
(
float
*
)
tensor
->
raw_mutable_data
();
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
for
(
auto
j
=
0
;
j
<
10
;
++
j
)
{
std
::
cout
<<
p
[
i
*
10
+
j
]
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
vector
<
SPTAG
::
QueryResult
>
query_results
=
ConvertToQueryResult
(
dataset
,
config
);
#pragma omp parallel for
for
(
auto
i
=
0
;
i
<
query_results
.
size
();
++
i
)
{
auto
target
=
(
float
*
)
query_results
[
i
].
GetTarget
();
std
::
cout
<<
target
[
0
]
<<
", "
<<
target
[
1
]
<<
", "
<<
target
[
2
]
<<
std
::
endl
;
index_ptr_
->
SearchIndex
(
query_results
[
i
]);
}
return
ConvertToDataset
(
query_results
);
}
int64_t
CPUSPTAGRNG
::
Count
()
{
return
index_ptr_
->
GetNumSamples
();
}
int64_t
CPUSPTAGRNG
::
Dimension
()
{
return
index_ptr_
->
GetFeatureDim
();
}
VectorIndexPtr
CPUSPTAGRNG
::
Clone
()
{
KNOWHERE_THROW_MSG
(
"not support"
);
}
void
CPUSPTAGRNG
::
Seal
()
{
return
;
// do nothing
}
BinarySet
CPUSPTAGRNGIndexModel
::
Serialize
()
{
// KNOWHERE_THROW_MSG("not support"); // not support
}
void
CPUSPTAGRNGIndexModel
::
Load
(
const
BinarySet
&
binary
)
{
// KNOWHERE_THROW_MSG("not support"); // not support
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/Index
KDT
.h
→
core/src/index/knowhere/knowhere/index/vector_index/Index
SPTAG
.h
浏览文件 @
bfd4c6c5
...
@@ -18,70 +18,76 @@
...
@@ -18,70 +18,76 @@
#pragma once
#pragma once
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <cstdint>
#include <cstdint>
#include <memory>
#include <memory>
#include <string>
#include "VectorIndex.h"
#include "VectorIndex.h"
#include "knowhere/index/IndexModel.h"
#include "knowhere/index/IndexModel.h"
namespace
knowhere
{
namespace
knowhere
{
class
CPUKDTRNG
:
public
VectorIndex
{
class
CPUSPTAGRNG
:
public
VectorIndex
{
public:
public:
CPUKDTRNG
()
{
explicit
CPUSPTAGRNG
(
const
std
::
string
&
IndexType
);
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
KDT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
public:
}
BinarySet
Serialize
()
override
;
public:
BinarySet
VectorIndexPtr
Serialize
()
override
;
Clone
()
override
;
VectorIndexPtr
Clone
()
override
;
void
void
Load
(
const
BinarySet
&
index_array
)
override
;
Load
(
const
BinarySet
&
index_array
)
override
;
public:
public:
// PreprocessorPtr
// PreprocessorPtr
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
int64_t
int64_t
Count
()
override
;
Count
()
override
;
int64_t
Dimension
()
override
;
int64_t
Dimension
()
override
;
IndexModelPtr
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
IndexModelPtr
Train
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
void
Add
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
void
Add
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
DatasetPtr
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
DatasetPtr
void
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
Seal
()
override
;
void
private:
Seal
()
override
;
void
SetParameters
(
const
Config
&
config
);
private:
void
private:
SetParameters
(
const
Config
&
config
);
PreprocessorPtr
preprocessor_
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_ptr_
;
private:
};
PreprocessorPtr
preprocessor_
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_ptr_
;
using
CPUKDTRNGPtr
=
std
::
shared_ptr
<
CPUKDTRNG
>
;
SPTAG
::
IndexAlgoType
index_type_
;
};
class
CPUKDTRNGIndexModel
:
public
IndexModel
{
public:
using
CPUSPTAGRNGPtr
=
std
::
shared_ptr
<
CPUSPTAGRNG
>
;
BinarySet
Serialize
()
override
;
class
CPUSPTAGRNGIndexModel
:
public
IndexModel
{
public:
void
BinarySet
Load
(
const
BinarySet
&
binary
)
override
;
Serialize
()
override
;
private:
void
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_
;
Load
(
const
BinarySet
&
binary
)
override
;
};
private:
using
CPUKDTRNGIndexModelPtr
=
std
::
shared_ptr
<
CPUKDTRNGIndexModel
>
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_
;
};
using
CPUSPTAGRNGIndexModelPtr
=
std
::
shared_ptr
<
CPUSPTAGRNGIndexModel
>
;
}
// namespace knowhere
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
浏览文件 @
bfd4c6c5
...
@@ -42,6 +42,32 @@ constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
...
@@ -42,6 +42,32 @@ constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
constexpr
int64_t
DEFAULT_CANDIDATE_SISE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_CANDIDATE_SISE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NNG_K
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NNG_K
=
INVALID_VALUE
;
// SPTAG Config
constexpr
int64_t
DEFAULT_SAMPLES
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_TPTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_TPTLEAFSIZE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMTOPDIMENSIONTPTSPLIT
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NEIGHBORHOODSIZE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_GRAPHNEIGHBORHOODSCALE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_GRAPHCEFSCALE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_REFINEITERATIONS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_CEF
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_MAXCHECKFORREFINEGRAPH
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMOFTHREADS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_MAXCHECK
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS
=
INVALID_VALUE
;
// KDT Config
constexpr
int64_t
DEFAULT_KDTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMTOPDIMENSIONKDTSPLIT
=
INVALID_VALUE
;
// BKT Config
constexpr
int64_t
DEFAULT_BKTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_BKTKMEANSK
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_BKTLEAFSIZE
=
INVALID_VALUE
;
struct
IVFCfg
:
public
Cfg
{
struct
IVFCfg
:
public
Cfg
{
int64_t
nlist
=
DEFAULT_NLIST
;
int64_t
nlist
=
DEFAULT_NLIST
;
int64_t
nprobe
=
DEFAULT_NPROBE
;
int64_t
nprobe
=
DEFAULT_NPROBE
;
...
@@ -126,8 +152,57 @@ struct NSGCfg : public IVFCfg {
...
@@ -126,8 +152,57 @@ struct NSGCfg : public IVFCfg {
};
};
using
NSGConfig
=
std
::
shared_ptr
<
NSGCfg
>
;
using
NSGConfig
=
std
::
shared_ptr
<
NSGCfg
>
;
struct
KDTCfg
:
public
Cfg
{
struct
SPTAGCfg
:
public
Cfg
{
int64_t
tptnubmber
=
-
1
;
int64_t
samples
=
DEFAULT_SAMPLES
;
int64_t
tptnumber
=
DEFAULT_TPTNUMBER
;
int64_t
tptleafsize
=
DEFAULT_TPTLEAFSIZE
;
int64_t
numtopdimensiontptsplit
=
DEFAULT_NUMTOPDIMENSIONTPTSPLIT
;
int64_t
neighborhoodsize
=
DEFAULT_NEIGHBORHOODSIZE
;
int64_t
graphneighborhoodscale
=
DEFAULT_GRAPHNEIGHBORHOODSCALE
;
int64_t
graphcefscale
=
DEFAULT_GRAPHCEFSCALE
;
int64_t
refineiterations
=
DEFAULT_REFINEITERATIONS
;
int64_t
cef
=
DEFAULT_CEF
;
int64_t
maxcheckforrefinegraph
=
DEFAULT_MAXCHECKFORREFINEGRAPH
;
int64_t
numofthreads
=
DEFAULT_NUMOFTHREADS
;
int64_t
maxcheck
=
DEFAULT_MAXCHECK
;
int64_t
thresholdofnumberofcontinuousnobetterpropagation
=
DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION
;
int64_t
numberofinitialdynamicpivots
=
DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS
;
int64_t
numberofotherdynamicpivots
=
DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS
;
SPTAGCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
using
SPTAGConfig
=
std
::
shared_ptr
<
SPTAGCfg
>
;
struct
KDTCfg
:
public
SPTAGCfg
{
int64_t
kdtnumber
=
DEFAULT_KDTNUMBER
;
int64_t
numtopdimensionkdtsplit
=
DEFAULT_NUMTOPDIMENSIONKDTSPLIT
;
KDTCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
using
KDTConfig
=
std
::
shared_ptr
<
KDTCfg
>
;
struct
BKTCfg
:
public
SPTAGCfg
{
int64_t
bktnumber
=
DEFAULT_BKTNUMBER
;
int64_t
bktkmeansk
=
DEFAULT_BKTKMEANSK
;
int64_t
bktleafsize
=
DEFAULT_BKTLEAFSIZE
;
BKTCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
};
using
BKTConfig
=
std
::
shared_ptr
<
BKTCfg
>
;
}
// namespace knowhere
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp
已删除
100644 → 0
浏览文件 @
0aa90ea9
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <mutex>
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
namespace
knowhere
{
const
std
::
vector
<
KDTParameter
>&
KDTParameterMgr
::
GetKDTParameters
()
{
return
kdt_parameters_
;
}
KDTParameterMgr
::
KDTParameterMgr
()
{
kdt_parameters_
=
std
::
vector
<
KDTParameter
>
{
{
"KDTNumber"
,
"1"
},
{
"NumTopDimensionKDTSplit"
,
"5"
},
{
"NumSamplesKDTSplitConsideration"
,
"100"
},
{
"TPTNumber"
,
"1"
},
{
"TPTLeafSize"
,
"2000"
},
{
"NumTopDimensionTPTSplit"
,
"5"
},
{
"NeighborhoodSize"
,
"32"
},
{
"GraphNeighborhoodScale"
,
"2"
},
{
"GraphCEFScale"
,
"2"
},
{
"RefineIterations"
,
"0"
},
{
"CEF"
,
"1000"
},
{
"MaxCheckForRefineGraph"
,
"10000"
},
{
"NumberOfThreads"
,
"1"
},
{
"MaxCheck"
,
"8192"
},
{
"ThresholdOfNumberOfContinuousNoBetterPropagation"
,
"3"
},
{
"NumberOfInitialDynamicPivots"
,
"50"
},
{
"NumberOfOtherDynamicPivots"
,
"4"
},
};
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
0 → 100644
浏览文件 @
bfd4c6c5
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <mutex>
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
namespace
knowhere
{
const
KDTConfig
&
SPTAGParameterMgr
::
GetKDTParameters
()
{
return
kdt_config_
;
}
const
BKTConfig
&
SPTAGParameterMgr
::
GetBKTParameters
()
{
return
bkt_config_
;
}
SPTAGParameterMgr
::
SPTAGParameterMgr
()
{
kdt_config_
=
std
::
make_shared
<
KDTCfg
>
();
kdt_config_
->
kdtnumber
=
1
;
kdt_config_
->
numtopdimensionkdtsplit
=
5
;
kdt_config_
->
samples
=
100
;
kdt_config_
->
tptnumber
=
1
;
kdt_config_
->
tptleafsize
=
2000
;
kdt_config_
->
numtopdimensiontptsplit
=
5
;
kdt_config_
->
neighborhoodsize
=
32
;
kdt_config_
->
graphneighborhoodscale
=
2
;
kdt_config_
->
graphcefscale
=
2
;
kdt_config_
->
refineiterations
=
0
;
kdt_config_
->
cef
=
1000
;
kdt_config_
->
maxcheckforrefinegraph
=
10000
;
kdt_config_
->
numofthreads
=
1
;
kdt_config_
->
maxcheck
=
8192
;
kdt_config_
->
thresholdofnumberofcontinuousnobetterpropagation
=
3
;
kdt_config_
->
numberofinitialdynamicpivots
=
50
;
kdt_config_
->
numberofotherdynamicpivots
=
4
;
bkt_config_
=
std
::
make_shared
<
BKTCfg
>
();
bkt_config_
->
bktnumber
=
1
;
bkt_config_
->
bktkmeansk
=
32
;
bkt_config_
->
bktleafsize
=
8
;
bkt_config_
->
samples
=
100
;
bkt_config_
->
tptnumber
=
1
;
bkt_config_
->
tptleafsize
=
2000
;
bkt_config_
->
numtopdimensiontptsplit
=
5
;
bkt_config_
->
neighborhoodsize
=
32
;
bkt_config_
->
graphneighborhoodscale
=
2
;
bkt_config_
->
graphcefscale
=
2
;
bkt_config_
->
refineiterations
=
0
;
bkt_config_
->
cef
=
1000
;
bkt_config_
->
maxcheckforrefinegraph
=
10000
;
bkt_config_
->
numofthreads
=
1
;
bkt_config_
->
maxcheck
=
8192
;
bkt_config_
->
thresholdofnumberofcontinuousnobetterpropagation
=
3
;
bkt_config_
->
numberofinitialdynamicpivots
=
50
;
bkt_config_
->
numberofotherdynamicpivots
=
4
;
}
}
// namespace knowhere
\ No newline at end of file
core/src/index/knowhere/knowhere/index/vector_index/helpers/
KDT
ParameterMgr.h
→
core/src/index/knowhere/knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.h
浏览文件 @
bfd4c6c5
...
@@ -22,31 +22,40 @@
...
@@ -22,31 +22,40 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <SPTAG/AnnService/inc/Core/Common.h>
#include "IndexParameter.h"
namespace
knowhere
{
namespace
knowhere
{
using
KDTParameter
=
std
::
pair
<
std
::
string
,
std
::
string
>
;
using
KDTConfig
=
std
::
shared_ptr
<
KDTCfg
>
;
using
BKTConfig
=
std
::
shared_ptr
<
BKTCfg
>
;
class
SPTAGParameterMgr
{
public:
const
KDTConfig
&
GetKDTParameters
();
const
BKTConfig
&
GetBKTParameters
();
class
KDTParameterMgr
{
public:
public:
static
SPTAGParameterMgr
&
const
std
::
vector
<
KDTParameter
>&
GetInstance
()
{
GetKDTParameters
();
static
SPTAGParameterMgr
instance
;
return
instance
;
}
public:
SPTAGParameterMgr
(
const
SPTAGParameterMgr
&
)
=
delete
;
static
KDTParameterMgr
&
GetInstance
()
{
static
KDTParameterMgr
instance
;
return
instance
;
}
KDTParameterMgr
(
const
KDTParameterMgr
&
)
=
delete
;
SPTAGParameterMgr
&
KDTParameterMgr
&
operator
=
(
const
SPTAGParameterMgr
&
)
=
delete
;
operator
=
(
const
KDTParameterMgr
&
)
=
delete
;
private:
private:
KDT
ParameterMgr
();
SPTAG
ParameterMgr
();
private:
private:
std
::
vector
<
KDTParameter
>
kdt_parameters_
;
KDTConfig
kdt_config_
;
};
BKTConfig
bkt_config_
;
};
}
// namespace knowhere
}
// namespace knowhere
core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
浏览文件 @
bfd4c6c5
...
@@ -195,7 +195,7 @@ namespace SPTAG
...
@@ -195,7 +195,7 @@ namespace SPTAG
C
=
*
((
DimensionType
*
)
pDataPointsMemFile
);
C
=
*
((
DimensionType
*
)
pDataPointsMemFile
);
pDataPointsMemFile
+=
sizeof
(
DimensionType
);
pDataPointsMemFile
+=
sizeof
(
DimensionType
);
Initialize
(
R
,
C
,
(
T
*
)
pDataPointsMemFile
);
Initialize
(
R
,
C
,
(
T
*
)
pDataPointsMemFile
,
false
);
std
::
cout
<<
"Load "
<<
name
<<
" ("
<<
R
<<
", "
<<
C
<<
") Finish!"
<<
std
::
endl
;
std
::
cout
<<
"Load "
<<
name
<<
" ("
<<
R
<<
", "
<<
C
<<
") Finish!"
<<
std
::
endl
;
return
true
;
return
true
;
}
}
...
...
core/src/index/unittest/CMakeLists.txt
浏览文件 @
bfd4c6c5
...
@@ -82,17 +82,17 @@ if (NOT TARGET test_idmap)
...
@@ -82,17 +82,17 @@ if (NOT TARGET test_idmap)
endif
()
endif
()
target_link_libraries
(
test_idmap
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
target_link_libraries
(
test_idmap
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
#<
KDT
-TEST>
#<
SPTAG
-TEST>
set
(
kdt
_srcs
set
(
sptag
_srcs
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/adapter/SptagAdapter.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/adapter/SptagAdapter.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/preprocessor/Normalize.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/preprocessor/Normalize.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/helpers/
KDT
ParameterMgr.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/Index
KDT
.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/Index
SPTAG
.cpp
)
)
if
(
NOT TARGET test_
kdt
)
if
(
NOT TARGET test_
sptag
)
add_executable
(
test_
kdt test_kdt.cpp
${
kdt
_srcs
}
${
util_srcs
}
)
add_executable
(
test_
sptag test_sptag.cpp
${
sptag
_srcs
}
${
util_srcs
}
)
endif
()
endif
()
target_link_libraries
(
test_
kdt
target_link_libraries
(
test_
sptag
SPTAGLibStatic
SPTAGLibStatic
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
...
@@ -106,7 +106,7 @@ endif ()
...
@@ -106,7 +106,7 @@ endif ()
install
(
TARGETS test_ivf DESTINATION unittest
)
install
(
TARGETS test_ivf DESTINATION unittest
)
install
(
TARGETS test_idmap DESTINATION unittest
)
install
(
TARGETS test_idmap DESTINATION unittest
)
install
(
TARGETS test_
kdt
DESTINATION unittest
)
install
(
TARGETS test_
sptag
DESTINATION unittest
)
if
(
KNOWHERE_GPU_VERSION
)
if
(
KNOWHERE_GPU_VERSION
)
install
(
TARGETS test_gpuresource DESTINATION unittest
)
install
(
TARGETS test_gpuresource DESTINATION unittest
)
install
(
TARGETS test_customized_index DESTINATION unittest
)
install
(
TARGETS test_customized_index DESTINATION unittest
)
...
...
core/src/index/unittest/test_
kdt
.cpp
→
core/src/index/unittest/test_
sptag
.cpp
浏览文件 @
bfd4c6c5
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/Index
KDT
.h"
#include "knowhere/index/vector_index/Index
SPTAG
.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "unittest/utils.h"
#include "unittest/utils.h"
...
@@ -32,28 +32,38 @@ using ::testing::Combine;
...
@@ -32,28 +32,38 @@ using ::testing::Combine;
using
::
testing
::
TestWithParam
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
Values
;
using
::
testing
::
Values
;
class
KDTTest
:
public
DataGen
,
public
::
testing
::
Test
{
class
SPTAGTest
:
public
DataGen
,
public
TestWithParam
<
std
::
string
>
{
protected:
protected:
void
void
SetUp
()
override
{
SetUp
()
override
{
Generate
(
96
,
1000
,
10
);
IndexType
=
GetParam
();
index_
=
std
::
make_shared
<
knowhere
::
CPUKDTRNG
>
();
Generate
(
128
,
100
,
5
);
index_
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
auto
tempconf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
if
(
IndexType
==
"KDT"
)
{
tempconf
->
tptnubmber
=
1
;
auto
tempconf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
tempconf
->
k
=
10
;
tempconf
->
tptnumber
=
1
;
conf
=
tempconf
;
tempconf
->
k
=
10
;
conf
=
tempconf
;
}
else
{
auto
tempconf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
tempconf
->
tptnumber
=
1
;
tempconf
->
k
=
10
;
conf
=
tempconf
;
}
Init_with_default
();
Init_with_default
();
}
}
protected:
protected:
knowhere
::
Config
conf
;
knowhere
::
Config
conf
;
std
::
shared_ptr
<
knowhere
::
CPUKDTRNG
>
index_
=
nullptr
;
std
::
shared_ptr
<
knowhere
::
CPUSPTAGRNG
>
index_
=
nullptr
;
std
::
string
IndexType
;
};
};
INSTANTIATE_TEST_CASE_P
(
SPTAGParameters
,
SPTAGTest
,
Values
(
"KDT"
,
"BKT"
));
// TODO(lxj): add test about count() and dimension()
// TODO(lxj): add test about count() and dimension()
TEST_
F
(
KDTTest
,
kdt
_basic
)
{
TEST_
P
(
SPTAGTest
,
sptag
_basic
)
{
assert
(
!
xb
.
empty
());
assert
(
!
xb
.
empty
());
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
...
@@ -66,8 +76,8 @@ TEST_F(KDTTest, kdt_basic) {
...
@@ -66,8 +76,8 @@ TEST_F(KDTTest, kdt_basic) {
AssertAnns
(
result
,
nq
,
k
);
AssertAnns
(
result
,
nq
,
k
);
{
{
//
auto ids = result->array()[0];
//auto ids = result->array()[0];
//
auto dists = result->array()[1];
//auto dists = result->array()[1];
auto
ids
=
result
->
ids
();
auto
ids
=
result
->
ids
();
auto
dists
=
result
->
dist
();
auto
dists
=
result
->
dist
();
...
@@ -75,10 +85,10 @@ TEST_F(KDTTest, kdt_basic) {
...
@@ -75,10 +85,10 @@ TEST_F(KDTTest, kdt_basic) {
std
::
stringstream
ss_dist
;
std
::
stringstream
ss_dist
;
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
//ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
//ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
}
}
ss_id
<<
std
::
endl
;
ss_id
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
...
@@ -88,57 +98,57 @@ TEST_F(KDTTest, kdt_basic) {
...
@@ -88,57 +98,57 @@ TEST_F(KDTTest, kdt_basic) {
}
}
}
}
// TODO(zirui): enable test
TEST_P
(
SPTAGTest
,
sptag_serialize
)
{
// TEST_F(KDTTest, kdt_serialize) {
assert
(
!
xb
.
empty
());
// assert(!xb.empty());
//
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
// auto preprocessor = index_->BuildPreprocessor(base_dataset, conf
);
index_
->
set_preprocessor
(
preprocessor
);
// index_->set_preprocessor(preprocessor);
//
auto
model
=
index_
->
Train
(
base_dataset
,
conf
);
// auto model = index_->Train(base_dataset, conf);
// //
index_->Add(base_dataset, conf);
index_
->
Add
(
base_dataset
,
conf
);
//
auto binaryset = index_->Serialize();
auto
binaryset
=
index_
->
Serialize
();
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>(
);
auto
new_index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
//
new_index->Load(binaryset);
new_index
->
Load
(
binaryset
);
//
auto result = new_index->Search(query_dataset, conf);
auto
result
=
new_index
->
Search
(
query_dataset
,
conf
);
//
AssertAnns(result, nq, k);
AssertAnns
(
result
,
nq
,
k
);
//
PrintResult(result, nq, k);
PrintResult
(
result
,
nq
,
k
);
//
ASSERT_EQ(new_index->Count(), nb);
ASSERT_EQ
(
new_index
->
Count
(),
nb
);
//
ASSERT_EQ(new_index->Dimension(), dim);
ASSERT_EQ
(
new_index
->
Dimension
(),
dim
);
//
ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
//
ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
//
ASSERT_NO_THROW({ new_index->Seal(); });
//
ASSERT_NO_THROW({ new_index->Seal(); });
//
//
{
{
//
int fileno = 0;
int
fileno
=
0
;
// const std::string& base_name = "/tmp/kdt
_serialize_test_bin_";
const
std
::
string
&
base_name
=
"/tmp/sptag
_serialize_test_bin_"
;
//
std::vector<std::string> filename_list;
std
::
vector
<
std
::
string
>
filename_list
;
//
std::vector<std::pair<std::string, size_t>> meta_list;
std
::
vector
<
std
::
pair
<
std
::
string
,
size_t
>>
meta_list
;
//
for (auto& iter : binaryset.binary_map_) {
for
(
auto
&
iter
:
binaryset
.
binary_map_
)
{
//
const std::string& filename = base_name + std::to_string(fileno);
const
std
::
string
&
filename
=
base_name
+
std
::
to_string
(
fileno
);
//
FileIOWriter writer(filename);
FileIOWriter
writer
(
filename
);
//
writer(iter.second->data.get(), iter.second->size);
writer
(
iter
.
second
->
data
.
get
(),
iter
.
second
->
size
);
//
//
meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
meta_list
.
emplace_back
(
std
::
make_pair
(
iter
.
first
,
iter
.
second
->
size
));
//
filename_list.push_back(filename);
filename_list
.
push_back
(
filename
);
//
++fileno;
++
fileno
;
//
}
}
//
//
knowhere::BinarySet load_data_list;
knowhere
::
BinarySet
load_data_list
;
//
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
for
(
int
i
=
0
;
i
<
filename_list
.
size
()
&&
i
<
meta_list
.
size
();
++
i
)
{
//
auto bin_size = meta_list[i].second;
auto
bin_size
=
meta_list
[
i
].
second
;
//
FileIOReader reader(filename_list[i]);
FileIOReader
reader
(
filename_list
[
i
]);
//
//
auto load_data = new uint8_t[bin_size];
auto
load_data
=
new
uint8_t
[
bin_size
];
//
reader(load_data, bin_size);
reader
(
load_data
,
bin_size
);
//
auto data = std::make_shared<uint8_t>();
auto
data
=
std
::
make_shared
<
uint8_t
>
();
//
data.reset(load_data);
data
.
reset
(
load_data
);
//
load_data_list.Append(meta_list[i].first, data, bin_size);
load_data_list
.
Append
(
meta_list
[
i
].
first
,
data
,
bin_size
);
//
}
}
//
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>(
);
auto
new_index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
//
new_index->Load(load_data_list);
new_index
->
Load
(
load_data_list
);
//
auto result = new_index->Search(query_dataset, conf);
auto
result
=
new_index
->
Search
(
query_dataset
,
conf
);
//
AssertAnns(result, nq, k);
AssertAnns
(
result
,
nq
,
k
);
//
PrintResult(result, nq, k);
PrintResult
(
result
,
nq
,
k
);
//
}
}
//
}
}
core/src/index/unittest/utils.cpp
浏览文件 @
bfd4c6c5
...
@@ -153,22 +153,24 @@ void
...
@@ -153,22 +153,24 @@ void
AssertAnns
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
AssertAnns
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
auto
ids
=
result
->
ids
();
auto
ids
=
result
->
ids
();
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
EXPECT_EQ
(
i
,
*
((
int64_t
*
)(
ids
)
+
i
*
k
));
EXPECT_EQ
(
i
,
*
((
int64_t
*
)(
ids
)
+
i
*
k
));
// EXPECT_EQ(i, *(ids->data()->GetValues<int64_t>(1, i * k)));
// EXPECT_EQ(i, *(ids->data()->GetValues<int64_t>(1, i * k)));
}
}
}
}
void
void
PrintResult
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
PrintResult
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
auto
ids
=
result
->
array
()[
0
]
;
auto
ids
=
result
->
ids
()
;
auto
dists
=
result
->
array
()[
1
]
;
auto
dists
=
result
->
dist
()
;
std
::
stringstream
ss_id
;
std
::
stringstream
ss_id
;
std
::
stringstream
ss_dist
;
std
::
stringstream
ss_dist
;
for
(
auto
i
=
0
;
i
<
10
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
ss_id
<<
*
(
ids
->
data
()
->
GetValues
<
int64_t
>
(
1
,
i
*
k
+
j
))
<<
" "
;
//ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
ss_dist
<<
*
(
dists
->
data
()
->
GetValues
<
float
>
(
1
,
i
*
k
+
j
))
<<
" "
;
//ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
}
}
ss_id
<<
std
::
endl
;
ss_id
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
...
...
core/src/wrapper/ConfAdapter.cpp
浏览文件 @
bfd4c6c5
...
@@ -204,5 +204,35 @@ NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
...
@@ -204,5 +204,35 @@ NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
return
conf
;
return
conf
;
}
}
knowhere
::
Config
SPTAGKDTConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
conf
->
d
=
metaconf
.
dim
;
conf
->
metric_type
=
metaconf
.
metric_type
;
return
conf
;
}
knowhere
::
Config
SPTAGKDTConfAdapter
::
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
conf
->
k
=
metaconf
.
k
;
return
conf
;
}
knowhere
::
Config
SPTAGBKTConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
conf
->
d
=
metaconf
.
dim
;
conf
->
metric_type
=
metaconf
.
metric_type
;
return
conf
;
}
knowhere
::
Config
SPTAGBKTConfAdapter
::
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
conf
->
k
=
metaconf
.
k
;
return
conf
;
}
}
// namespace engine
}
// namespace engine
}
// namespace milvus
}
// namespace milvus
core/src/wrapper/ConfAdapter.h
浏览文件 @
bfd4c6c5
...
@@ -97,5 +97,23 @@ class NSGConfAdapter : public IVFConfAdapter {
...
@@ -97,5 +97,23 @@ class NSGConfAdapter : public IVFConfAdapter {
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
final
;
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
final
;
};
};
class
SPTAGKDTConfAdapter
:
public
ConfAdapter
{
public:
knowhere
::
Config
Match
(
const
TempMetaConf
&
metaconf
)
override
;
knowhere
::
Config
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
override
;
};
class
SPTAGBKTConfAdapter
:
public
ConfAdapter
{
public:
knowhere
::
Config
Match
(
const
TempMetaConf
&
metaconf
)
override
;
knowhere
::
Config
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
override
;
};
}
// namespace engine
}
// namespace engine
}
// namespace milvus
}
// namespace milvus
core/src/wrapper/ConfAdapterMgr.cpp
浏览文件 @
bfd4c6c5
...
@@ -56,6 +56,9 @@ AdapterMgr::RegisterAdapter() {
...
@@ -56,6 +56,9 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER
(
IVFPQConfAdapter
,
IndexType
::
FAISS_IVFPQ_MIX
,
ivfpq_mix
);
REGISTER_CONF_ADAPTER
(
IVFPQConfAdapter
,
IndexType
::
FAISS_IVFPQ_MIX
,
ivfpq_mix
);
REGISTER_CONF_ADAPTER
(
NSGConfAdapter
,
IndexType
::
NSG_MIX
,
nsg_mix
);
REGISTER_CONF_ADAPTER
(
NSGConfAdapter
,
IndexType
::
NSG_MIX
,
nsg_mix
);
REGISTER_CONF_ADAPTER
(
SPTAGKDTConfAdapter
,
IndexType
::
SPTAG_KDT_RNT_CPU
,
sptag_kdt
);
REGISTER_CONF_ADAPTER
(
SPTAGBKTConfAdapter
,
IndexType
::
SPTAG_BKT_RNT_CPU
,
sptag_bkt
);
}
}
}
// namespace engine
}
// namespace engine
...
...
core/src/wrapper/VecIndex.cpp
浏览文件 @
bfd4c6c5
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/Index
KDT
.h"
#include "knowhere/index/vector_index/Index
SPTAG
.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "utils/Log.h"
#include "utils/Log.h"
...
@@ -128,7 +128,11 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
...
@@ -128,7 +128,11 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
break
;
break
;
}
}
case
IndexType
::
SPTAG_KDT_RNT_CPU
:
{
case
IndexType
::
SPTAG_KDT_RNT_CPU
:
{
index
=
std
::
make_shared
<
knowhere
::
CPUKDTRNG
>
();
index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
"KDT"
);
break
;
}
case
IndexType
::
SPTAG_BKT_RNT_CPU
:
{
index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
"BKT"
);
break
;
break
;
}
}
case
IndexType
::
FAISS_IVFSQ8_CPU
:
{
case
IndexType
::
FAISS_IVFSQ8_CPU
:
{
...
...
core/src/wrapper/VecIndex.h
浏览文件 @
bfd4c6c5
...
@@ -49,6 +49,7 @@ enum class IndexType {
...
@@ -49,6 +49,7 @@ enum class IndexType {
FAISS_IVFSQ8_HYBRID
,
// only support build on gpu.
FAISS_IVFSQ8_HYBRID
,
// only support build on gpu.
NSG_MIX
,
NSG_MIX
,
FAISS_IVFPQ_MIX
,
FAISS_IVFPQ_MIX
,
SPTAG_BKT_RNT_CPU
,
};
};
class
VecIndex
;
class
VecIndex
;
...
@@ -139,6 +140,9 @@ write_index(VecIndexPtr index, const std::string& location);
...
@@ -139,6 +140,9 @@ write_index(VecIndexPtr index, const std::string& location);
extern
VecIndexPtr
extern
VecIndexPtr
read_index
(
const
std
::
string
&
location
);
read_index
(
const
std
::
string
&
location
);
VecIndexPtr
read_index
(
const
std
::
string
&
location
,
knowhere
::
BinarySet
&
index_binary
);
extern
VecIndexPtr
extern
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
...
...
core/unittest/wrapper/test_wrapper.cpp
浏览文件 @
bfd4c6c5
...
@@ -29,15 +29,16 @@
...
@@ -29,15 +29,16 @@
INITIALIZE_EASYLOGGINGPP
INITIALIZE_EASYLOGGINGPP
using
::
testing
::
Combine
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
Values
;
using
::
testing
::
Values
;
using
::
testing
::
Combine
;
class
KnowhereWrapperTest
class
KnowhereWrapperTest
:
public
DataGenBase
,
:
public
DataGenBase
,
public
TestWithParam
<::
std
::
tuple
<
milvus
::
engine
::
IndexType
,
std
::
string
,
int
,
int
,
int
,
int
>>
{
public
TestWithParam
<::
std
::
tuple
<
milvus
::
engine
::
IndexType
,
std
::
string
,
int
,
int
,
int
,
int
>>
{
protected:
protected:
void
SetUp
()
override
{
void
SetUp
()
override
{
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
InitDevice
(
DEVICEID
,
PINMEM
,
TEMPMEM
,
RESNUM
);
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
InitDevice
(
DEVICEID
,
PINMEM
,
TEMPMEM
,
RESNUM
);
#endif
#endif
...
@@ -57,12 +58,13 @@ class KnowhereWrapperTest
...
@@ -57,12 +58,13 @@ class KnowhereWrapperTest
conf
=
ParamGenerator
::
GetInstance
().
GenBuild
(
index_type
,
tempconf
);
conf
=
ParamGenerator
::
GetInstance
().
GenBuild
(
index_type
,
tempconf
);
searchconf
=
ParamGenerator
::
GetInstance
().
GenSearchConf
(
index_type
,
tempconf
);
searchconf
=
ParamGenerator
::
GetInstance
().
GenSearchConf
(
index_type
,
tempconf
);
// conf->k = k;
// conf->k = k;
// conf->d = dim;
// conf->d = dim;
// conf->gpu_id = DEVICEID;
// conf->gpu_id = DEVICEID;
}
}
void
TearDown
()
override
{
void
TearDown
()
override
{
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
#endif
#endif
...
@@ -75,22 +77,21 @@ class KnowhereWrapperTest
...
@@ -75,22 +77,21 @@ class KnowhereWrapperTest
knowhere
::
Config
searchconf
;
knowhere
::
Config
searchconf
;
};
};
INSTANTIATE_TEST_CASE_P
(
WrapperParam
,
KnowhereWrapperTest
,
INSTANTIATE_TEST_CASE_P
(
Values
(
WrapperParam
,
KnowhereWrapperTest
,
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
Values
(
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
// std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB,
// 10, 10),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_MIX
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_MIX
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFPQ_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFPQ_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
// std::make_tuple(milvus::engine::IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
// std::make_tuple(IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
#endif
#endif
// std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 250000, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 100, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_BKT_RNT_CPU, "Default", 128, 100, 10, 10),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IDMAP
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IDMAP
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_CPU
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_CPU
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_CPU
,
"Default"
,
DIM
,
NB
,
10
,
10
)));
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_CPU
,
"Default"
,
DIM
,
NB
,
10
,
10
)));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录