Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
milvus
提交
e01eb29f
milvus
项目概览
BaiXuePrincess
/
milvus
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
milvus
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e01eb29f
编写于
11月 20, 2019
作者:
T
Tinkerrr
提交者:
GitHub
11月 20, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #438 from tinkerlin/0.6.0-#227
0.6.0 #227
上级
24e05a87
863cc5db
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
794 addition
and
235 deletion
+794
-235
CHANGELOG.md
CHANGELOG.md
+1
-0
core/src/db/engine/ExecutionEngine.h
core/src/db/engine/ExecutionEngine.h
+3
-1
core/src/db/engine/ExecutionEngineImpl.cpp
core/src/db/engine/ExecutionEngineImpl.cpp
+8
-0
core/src/index/knowhere/CMakeLists.txt
core/src/index/knowhere/CMakeLists.txt
+2
-2
core/src/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
...c/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
+0
-180
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
...index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
+348
-0
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h
...c/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h
+14
-8
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
...here/knowhere/index/vector_index/helpers/IndexParameter.h
+77
-2
core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
...knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp
+75
-0
core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h
...e/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h
+19
-10
core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
...dex/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
+1
-1
core/src/index/unittest/CMakeLists.txt
core/src/index/unittest/CMakeLists.txt
+8
-8
core/src/index/unittest/test_sptag.cpp
core/src/index/unittest/test_sptag.cpp
+154
-0
core/src/index/unittest/utils.cpp
core/src/index/unittest/utils.cpp
+7
-5
core/src/wrapper/ConfAdapter.cpp
core/src/wrapper/ConfAdapter.cpp
+30
-0
core/src/wrapper/ConfAdapter.h
core/src/wrapper/ConfAdapter.h
+18
-0
core/src/wrapper/ConfAdapterMgr.cpp
core/src/wrapper/ConfAdapterMgr.cpp
+3
-0
core/src/wrapper/VecImpl.cpp
core/src/wrapper/VecImpl.cpp
+1
-1
core/src/wrapper/VecIndex.cpp
core/src/wrapper/VecIndex.cpp
+6
-2
core/src/wrapper/VecIndex.h
core/src/wrapper/VecIndex.h
+4
-0
core/unittest/wrapper/test_wrapper.cpp
core/unittest/wrapper/test_wrapper.cpp
+15
-15
未找到文件。
CHANGELOG.md
浏览文件 @
e01eb29f
...
@@ -23,6 +23,7 @@ Please mark all change in change log and use the ticket from JIRA.
...
@@ -23,6 +23,7 @@ Please mark all change in change log and use the ticket from JIRA.
-
\#
77 - Support table partition
-
\#
77 - Support table partition
-
\#
127 - Support new Index type IVFPQ
-
\#
127 - Support new Index type IVFPQ
-
\#
226 - Experimental shards middleware for Milvus
-
\#
226 - Experimental shards middleware for Milvus
-
\#
227 - Support new index types SPTAG-KDT and SPTAG-BKT
-
\#
346 - Support build index with multiple gpu
-
\#
346 - Support build index with multiple gpu
## Improvement
## Improvement
...
...
core/src/db/engine/ExecutionEngine.h
浏览文件 @
e01eb29f
...
@@ -35,7 +35,9 @@ enum class EngineType {
...
@@ -35,7 +35,9 @@ enum class EngineType {
NSG_MIX
,
NSG_MIX
,
FAISS_IVFSQ8H
,
FAISS_IVFSQ8H
,
FAISS_PQ
,
FAISS_PQ
,
MAX_VALUE
=
FAISS_PQ
,
SPTAG_KDT
,
SPTAG_BKT
,
MAX_VALUE
=
SPTAG_BKT
,
};
};
enum
class
MetricType
{
enum
class
MetricType
{
...
...
core/src/db/engine/ExecutionEngineImpl.cpp
浏览文件 @
e01eb29f
...
@@ -124,6 +124,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
...
@@ -124,6 +124,14 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
#endif
#endif
break
;
break
;
}
}
case
EngineType
::
SPTAG_KDT
:
{
index
=
GetVecIndexFactory
(
IndexType
::
SPTAG_KDT_RNT_CPU
);
break
;
}
case
EngineType
::
SPTAG_BKT
:
{
index
=
GetVecIndexFactory
(
IndexType
::
SPTAG_BKT_RNT_CPU
);
break
;
}
default:
{
default:
{
ENGINE_LOG_ERROR
<<
"Unsupported index type"
;
ENGINE_LOG_ERROR
<<
"Unsupported index type"
;
return
nullptr
;
return
nullptr
;
...
...
core/src/index/knowhere/CMakeLists.txt
浏览文件 @
e01eb29f
...
@@ -30,10 +30,10 @@ set(external_srcs
...
@@ -30,10 +30,10 @@ set(external_srcs
set
(
index_srcs
set
(
index_srcs
knowhere/index/preprocessor/Normalize.cpp
knowhere/index/preprocessor/Normalize.cpp
knowhere/index/vector_index/Index
KDT
.cpp
knowhere/index/vector_index/Index
SPTAG
.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIDMAP.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/IndexIVF.cpp
knowhere/index/vector_index/helpers/
KDT
ParameterMgr.cpp
knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/IndexNSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSG.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
knowhere/index/vector_index/nsg/NSGIO.cpp
...
...
core/src/index/knowhere/knowhere/index/vector_index/IndexKDT.cpp
已删除
100644 → 0
浏览文件 @
24e05a87
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <sstream>
#include <vector>
#undef mkdir
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
//#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h"
namespace
knowhere
{
BinarySet
CPUKDTRNG
::
Serialize
()
{
std
::
vector
<
void
*>
index_blobs
;
std
::
vector
<
int64_t
>
index_len
;
// TODO(zirui): dev
// index_ptr_->SaveIndexToMemory(index_blobs, index_len);
BinarySet
binary_set
;
//
// auto sample = std::make_shared<uint8_t>();
// sample.reset(static_cast<uint8_t*>(index_blobs[0]));
// auto tree = std::make_shared<uint8_t>();
// tree.reset(static_cast<uint8_t*>(index_blobs[1]));
// auto graph = std::make_shared<uint8_t>();
// graph.reset(static_cast<uint8_t*>(index_blobs[2]));
// auto metadata = std::make_shared<uint8_t>();
// metadata.reset(static_cast<uint8_t*>(index_blobs[3]));
//
// binary_set.Append("samples", sample, index_len[0]);
// binary_set.Append("tree", tree, index_len[1]);
// binary_set.Append("graph", graph, index_len[2]);
// binary_set.Append("metadata", metadata, index_len[3]);
return
binary_set
;
}
void
CPUKDTRNG
::
Load
(
const
BinarySet
&
binary_set
)
{
// TODO(zirui): dev
// std::vector<void*> index_blobs;
//
// auto samples = binary_set.GetByName("samples");
// index_blobs.push_back(samples->data.get());
//
// auto tree = binary_set.GetByName("tree");
// index_blobs.push_back(tree->data.get());
//
// auto graph = binary_set.GetByName("graph");
// index_blobs.push_back(graph->data.get());
//
// auto metadata = binary_set.GetByName("metadata");
// index_blobs.push_back(metadata->data.get());
//
// index_ptr_->LoadIndexFromMemory(index_blobs);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
IndexModelPtr
CPUKDTRNG
::
Train
(
const
DatasetPtr
&
origin
,
const
Config
&
train_config
)
{
SetParameters
(
train_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
BuildIndex
(
vectorset
,
metaset
);
// TODO: return IndexModelPtr
return
nullptr
;
}
void
CPUKDTRNG
::
Add
(
const
DatasetPtr
&
origin
,
const
Config
&
add_config
)
{
SetParameters
(
add_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
AddIndex
(
vectorset
,
metaset
);
}
void
CPUKDTRNG
::
SetParameters
(
const
Config
&
config
)
{
for
(
auto
&
para
:
KDTParameterMgr
::
GetInstance
().
GetKDTParameters
())
{
// auto value = config.get_with_default(para.first, para.second);
index_ptr_
->
SetParameter
(
para
.
first
,
para
.
second
);
}
}
DatasetPtr
CPUKDTRNG
::
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
SetParameters
(
config
);
auto
tensor
=
dataset
->
tensor
()[
0
];
auto
p
=
(
float
*
)
tensor
->
raw_mutable_data
();
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
for
(
auto
j
=
0
;
j
<
10
;
++
j
)
{
std
::
cout
<<
p
[
i
*
10
+
j
]
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
vector
<
SPTAG
::
QueryResult
>
query_results
=
ConvertToQueryResult
(
dataset
,
config
);
#pragma omp parallel for
for
(
auto
i
=
0
;
i
<
query_results
.
size
();
++
i
)
{
auto
target
=
(
float
*
)
query_results
[
i
].
GetTarget
();
std
::
cout
<<
target
[
0
]
<<
", "
<<
target
[
1
]
<<
", "
<<
target
[
2
]
<<
std
::
endl
;
index_ptr_
->
SearchIndex
(
query_results
[
i
]);
}
return
ConvertToDataset
(
query_results
);
}
int64_t
CPUKDTRNG
::
Count
()
{
index_ptr_
->
GetNumSamples
();
}
int64_t
CPUKDTRNG
::
Dimension
()
{
index_ptr_
->
GetFeatureDim
();
}
VectorIndexPtr
CPUKDTRNG
::
Clone
()
{
KNOWHERE_THROW_MSG
(
"not support"
);
}
void
CPUKDTRNG
::
Seal
()
{
// do nothing
}
// TODO(linxj):
BinarySet
CPUKDTRNGIndexModel
::
Serialize
()
{
}
void
CPUKDTRNGIndexModel
::
Load
(
const
BinarySet
&
binary
)
{
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp
0 → 100644
浏览文件 @
e01eb29f
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <SPTAG/AnnService/inc/Core/Common.h>
#include <SPTAG/AnnService/inc/Core/VectorSet.h>
#include <SPTAG/AnnService/inc/Server/QueryParser.h>
#include <array>
#include <sstream>
#include <vector>
#undef mkdir
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h"
namespace
knowhere
{
CPUSPTAGRNG
::
CPUSPTAGRNG
(
const
std
::
string
&
IndexType
)
{
if
(
IndexType
==
"KDT"
)
{
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
KDT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
index_type_
=
SPTAG
::
IndexAlgoType
::
KDT
;
}
else
{
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
BKT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
index_type_
=
SPTAG
::
IndexAlgoType
::
BKT
;
}
}
BinarySet
CPUSPTAGRNG
::
Serialize
()
{
std
::
string
index_config
;
std
::
vector
<
SPTAG
::
ByteArray
>
index_blobs
;
std
::
shared_ptr
<
std
::
vector
<
std
::
uint64_t
>>
buffersize
=
index_ptr_
->
CalculateBufferSize
();
std
::
vector
<
char
*>
res
(
buffersize
->
size
()
+
1
);
for
(
uint64_t
i
=
1
;
i
<
res
.
size
();
i
++
)
{
res
[
i
]
=
new
char
[
buffersize
->
at
(
i
-
1
)];
auto
ptr
=
&
res
[
i
][
0
];
index_blobs
.
emplace_back
(
SPTAG
::
ByteArray
((
std
::
uint8_t
*
)
ptr
,
buffersize
->
at
(
i
-
1
),
false
));
}
index_ptr_
->
SaveIndex
(
index_config
,
index_blobs
);
size_t
length
=
index_config
.
length
();
char
*
cstr
=
new
char
[
length
];
snprintf
(
cstr
,
length
,
"%s"
,
index_config
.
c_str
());
BinarySet
binary_set
;
auto
sample
=
std
::
make_shared
<
uint8_t
>
();
sample
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
0
].
Data
()));
auto
tree
=
std
::
make_shared
<
uint8_t
>
();
tree
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
1
].
Data
()));
auto
graph
=
std
::
make_shared
<
uint8_t
>
();
graph
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
2
].
Data
()));
auto
deleteid
=
std
::
make_shared
<
uint8_t
>
();
deleteid
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
3
].
Data
()));
auto
metadata1
=
std
::
make_shared
<
uint8_t
>
();
metadata1
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
4
].
Data
()));
auto
metadata2
=
std
::
make_shared
<
uint8_t
>
();
metadata2
.
reset
(
static_cast
<
uint8_t
*>
(
index_blobs
[
5
].
Data
()));
auto
config
=
std
::
make_shared
<
uint8_t
>
();
config
.
reset
(
static_cast
<
uint8_t
*>
((
void
*
)
cstr
));
binary_set
.
Append
(
"samples"
,
sample
,
index_blobs
[
0
].
Length
());
binary_set
.
Append
(
"tree"
,
tree
,
index_blobs
[
1
].
Length
());
binary_set
.
Append
(
"deleteid"
,
deleteid
,
index_blobs
[
3
].
Length
());
binary_set
.
Append
(
"metadata1"
,
metadata1
,
index_blobs
[
4
].
Length
());
binary_set
.
Append
(
"metadata2"
,
metadata2
,
index_blobs
[
5
].
Length
());
binary_set
.
Append
(
"config"
,
config
,
length
);
binary_set
.
Append
(
"graph"
,
graph
,
index_blobs
[
2
].
Length
());
// MemoryIOWriter writer;
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// len = index_blobs[i].Length();
// assert(len != 0);
// writer(&len, sizeof(size_t), 1);
// writer(index_blobs[i].Data(), len, 1);
// len = 0;
// }
// writer(&length, sizeof(size_t), 1);
// writer(cstr, length, 1);
// auto data = std::make_shared<uint8_t>();
// data.reset(writer.data_);
// BinarySet binary_set;
// binary_set.Append("sptag", data, writer.total);
// MemoryIOWriter writer;
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// if (i == 2) continue;
// len = index_blobs[i].Length();
// assert(len != 0);
// writer(&len, sizeof(size_t), 1);
// writer(index_blobs[i].Data(), len, 1);
// len = 0;
// }
// writer(&length, sizeof(size_t), 1);
// writer(cstr, length, 1);
// auto data = std::make_shared<uint8_t>();
// data.reset(writer.data_);
// BinarySet binary_set;
// binary_set.Append("sptag", data, writer.total);
// auto graph = std::make_shared<uint8_t>();
// graph.reset(static_cast<uint8_t*>(index_blobs[2].Data()));
// binary_set.Append("graph", graph, index_blobs[2].Length());
return
binary_set
;
}
void
CPUSPTAGRNG
::
Load
(
const
BinarySet
&
binary_set
)
{
std
::
string
index_config
;
std
::
vector
<
SPTAG
::
ByteArray
>
index_blobs
;
auto
samples
=
binary_set
.
GetByName
(
"samples"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
samples
->
data
.
get
(),
samples
->
size
,
false
));
auto
tree
=
binary_set
.
GetByName
(
"tree"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
tree
->
data
.
get
(),
tree
->
size
,
false
));
auto
graph
=
binary_set
.
GetByName
(
"graph"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
graph
->
data
.
get
(),
graph
->
size
,
false
));
auto
deleteid
=
binary_set
.
GetByName
(
"deleteid"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
deleteid
->
data
.
get
(),
deleteid
->
size
,
false
));
auto
metadata1
=
binary_set
.
GetByName
(
"metadata1"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
metadata1
->
data
.
get
(),
metadata1
->
size
,
false
));
auto
metadata2
=
binary_set
.
GetByName
(
"metadata2"
);
index_blobs
.
push_back
(
SPTAG
::
ByteArray
(
metadata2
->
data
.
get
(),
metadata2
->
size
,
false
));
auto
config
=
binary_set
.
GetByName
(
"config"
);
index_config
=
reinterpret_cast
<
char
*>
(
config
->
data
.
get
());
// std::vector<SPTAG::ByteArray> index_blobs;
// auto data = binary_set.GetByName("sptag");
// MemoryIOReader reader;
// reader.total = data->size;
// reader.data_ = data->data.get();
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto binary = new uint8_t[len];
// reader(binary, len, 1);
// index_blobs.emplace_back(SPTAG::ByteArray(binary, len, true));
// len = 0;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto config = new char[len];
// reader(config, len, 1);
// std::string index_config = config;
// delete[] config;
// std::vector<SPTAG::ByteArray> index_blobs;
// auto data = binary_set.GetByName("sptag");
// MemoryIOReader reader;
// reader.total = data->size;
// reader.data_ = data->data.get();
// size_t len = 0;
// for (int i = 0; i < 6; ++i) {
// if (i == 2) {
// auto graph = binary_set.GetByName("graph");
// index_blobs.emplace_back(SPTAG::ByteArray(graph->data.get(), graph->size, false));
// continue;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto binary = new uint8_t[len];
// reader(binary, len, 1);
// index_blobs.emplace_back(SPTAG::ByteArray(binary, len, true));
// len = 0;
// }
// reader(&len, sizeof(size_t), 1);
// assert(len != 0);
// auto config = new char[len];
// reader(config, len, 1);
// std::string index_config = config;
// delete[] config;
index_ptr_
->
LoadIndex
(
index_config
,
index_blobs
);
}
// PreprocessorPtr
// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) {
// return std::make_shared<NormalizePreprocessor>();
//}
IndexModelPtr
CPUSPTAGRNG
::
Train
(
const
DatasetPtr
&
origin
,
const
Config
&
train_config
)
{
SetParameters
(
train_config
);
DatasetPtr
dataset
=
origin
->
Clone
();
// if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// && preprocessor_) {
// preprocessor_->Preprocess(dataset);
//}
auto
vectorset
=
ConvertToVectorSet
(
dataset
);
auto
metaset
=
ConvertToMetadataSet
(
dataset
);
index_ptr_
->
BuildIndex
(
vectorset
,
metaset
);
// TODO: return IndexModelPtr
return
nullptr
;
}
void
CPUSPTAGRNG
::
Add
(
const
DatasetPtr
&
origin
,
const
Config
&
add_config
)
{
// SetParameters(add_config);
// DatasetPtr dataset = origin->Clone();
//
// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine
// // && preprocessor_) {
// // preprocessor_->Preprocess(dataset);
// //}
//
// auto vectorset = ConvertToVectorSet(dataset);
// auto metaset = ConvertToMetadataSet(dataset);
// index_ptr_->AddIndex(vectorset, metaset);
}
void
CPUSPTAGRNG
::
SetParameters
(
const
Config
&
config
)
{
#define Assign(param_name, str_name) \
conf->param_name == INVALID_VALUE ? index_ptr_->SetParameter(str_name, std::to_string(build_cfg->param_name)) \
: index_ptr_->SetParameter(str_name, std::to_string(conf->param_name))
if
(
index_type_
==
SPTAG
::
IndexAlgoType
::
KDT
)
{
auto
conf
=
std
::
dynamic_pointer_cast
<
KDTCfg
>
(
config
);
auto
build_cfg
=
SPTAGParameterMgr
::
GetInstance
().
GetKDTParameters
();
Assign
(
kdtnumber
,
"KDTNumber"
);
Assign
(
numtopdimensionkdtsplit
,
"NumTopDimensionKDTSplit"
);
Assign
(
samples
,
"Samples"
);
Assign
(
tptnumber
,
"TPTNumber"
);
Assign
(
tptleafsize
,
"TPTLeafSize"
);
Assign
(
numtopdimensiontptsplit
,
"NumTopDimensionTPTSplit"
);
Assign
(
neighborhoodsize
,
"NeighborhoodSize"
);
Assign
(
graphneighborhoodscale
,
"GraphNeighborhoodScale"
);
Assign
(
graphcefscale
,
"GraphCEFScale"
);
Assign
(
refineiterations
,
"RefineIterations"
);
Assign
(
cef
,
"CEF"
);
Assign
(
maxcheckforrefinegraph
,
"MaxCheckForRefineGraph"
);
Assign
(
numofthreads
,
"NumberOfThreads"
);
Assign
(
maxcheck
,
"MaxCheck"
);
Assign
(
thresholdofnumberofcontinuousnobetterpropagation
,
"ThresholdOfNumberOfContinuousNoBetterPropagation"
);
Assign
(
numberofinitialdynamicpivots
,
"NumberOfInitialDynamicPivots"
);
Assign
(
numberofotherdynamicpivots
,
"NumberOfOtherDynamicPivots"
);
}
else
{
auto
conf
=
std
::
dynamic_pointer_cast
<
BKTCfg
>
(
config
);
auto
build_cfg
=
SPTAGParameterMgr
::
GetInstance
().
GetBKTParameters
();
Assign
(
bktnumber
,
"BKTNumber"
);
Assign
(
bktkmeansk
,
"BKTKMeansK"
);
Assign
(
bktleafsize
,
"BKTLeafSize"
);
Assign
(
samples
,
"Samples"
);
Assign
(
tptnumber
,
"TPTNumber"
);
Assign
(
tptleafsize
,
"TPTLeafSize"
);
Assign
(
numtopdimensiontptsplit
,
"NumTopDimensionTPTSplit"
);
Assign
(
neighborhoodsize
,
"NeighborhoodSize"
);
Assign
(
graphneighborhoodscale
,
"GraphNeighborhoodScale"
);
Assign
(
graphcefscale
,
"GraphCEFScale"
);
Assign
(
refineiterations
,
"RefineIterations"
);
Assign
(
cef
,
"CEF"
);
Assign
(
maxcheckforrefinegraph
,
"MaxCheckForRefineGraph"
);
Assign
(
numofthreads
,
"NumberOfThreads"
);
Assign
(
maxcheck
,
"MaxCheck"
);
Assign
(
thresholdofnumberofcontinuousnobetterpropagation
,
"ThresholdOfNumberOfContinuousNoBetterPropagation"
);
Assign
(
numberofinitialdynamicpivots
,
"NumberOfInitialDynamicPivots"
);
Assign
(
numberofotherdynamicpivots
,
"NumberOfOtherDynamicPivots"
);
}
}
DatasetPtr
CPUSPTAGRNG
::
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
{
SetParameters
(
config
);
auto
tensor
=
dataset
->
tensor
()[
0
];
auto
p
=
(
float
*
)
tensor
->
raw_mutable_data
();
for
(
auto
i
=
0
;
i
<
10
;
++
i
)
{
for
(
auto
j
=
0
;
j
<
10
;
++
j
)
{
std
::
cout
<<
p
[
i
*
10
+
j
]
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
vector
<
SPTAG
::
QueryResult
>
query_results
=
ConvertToQueryResult
(
dataset
,
config
);
#pragma omp parallel for
for
(
auto
i
=
0
;
i
<
query_results
.
size
();
++
i
)
{
auto
target
=
(
float
*
)
query_results
[
i
].
GetTarget
();
std
::
cout
<<
target
[
0
]
<<
", "
<<
target
[
1
]
<<
", "
<<
target
[
2
]
<<
std
::
endl
;
index_ptr_
->
SearchIndex
(
query_results
[
i
]);
}
return
ConvertToDataset
(
query_results
);
}
int64_t
CPUSPTAGRNG
::
Count
()
{
return
index_ptr_
->
GetNumSamples
();
}
int64_t
CPUSPTAGRNG
::
Dimension
()
{
return
index_ptr_
->
GetFeatureDim
();
}
VectorIndexPtr
CPUSPTAGRNG
::
Clone
()
{
KNOWHERE_THROW_MSG
(
"not support"
);
}
void
CPUSPTAGRNG
::
Seal
()
{
return
;
// do nothing
}
BinarySet
CPUSPTAGRNGIndexModel
::
Serialize
()
{
// KNOWHERE_THROW_MSG("not support"); // not support
}
void
CPUSPTAGRNGIndexModel
::
Load
(
const
BinarySet
&
binary
)
{
// KNOWHERE_THROW_MSG("not support"); // not support
}
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/Index
KDT
.h
→
core/src/index/knowhere/knowhere/index/vector_index/Index
SPTAG
.h
浏览文件 @
e01eb29f
...
@@ -18,33 +18,37 @@
...
@@ -18,33 +18,37 @@
#pragma once
#pragma once
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <cstdint>
#include <cstdint>
#include <memory>
#include <memory>
#include <string>
#include "VectorIndex.h"
#include "VectorIndex.h"
#include "knowhere/index/IndexModel.h"
#include "knowhere/index/IndexModel.h"
namespace
knowhere
{
namespace
knowhere
{
class
CPU
KDT
RNG
:
public
VectorIndex
{
class
CPU
SPTAG
RNG
:
public
VectorIndex
{
public:
public:
CPUKDTRNG
()
{
explicit
CPUSPTAGRNG
(
const
std
::
string
&
IndexType
);
index_ptr_
=
SPTAG
::
VectorIndex
::
CreateInstance
(
SPTAG
::
IndexAlgoType
::
KDT
,
SPTAG
::
VectorValueType
::
Float
);
index_ptr_
->
SetParameter
(
"DistCalcMethod"
,
"L2"
);
}
public:
public:
BinarySet
BinarySet
Serialize
()
override
;
Serialize
()
override
;
VectorIndexPtr
VectorIndexPtr
Clone
()
override
;
Clone
()
override
;
void
void
Load
(
const
BinarySet
&
index_array
)
override
;
Load
(
const
BinarySet
&
index_array
)
override
;
public:
public:
// PreprocessorPtr
// PreprocessorPtr
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
// BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override;
int64_t
int64_t
Count
()
override
;
Count
()
override
;
int64_t
int64_t
Dimension
()
override
;
Dimension
()
override
;
...
@@ -56,6 +60,7 @@ class CPUKDTRNG : public VectorIndex {
...
@@ -56,6 +60,7 @@ class CPUKDTRNG : public VectorIndex {
DatasetPtr
DatasetPtr
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
Search
(
const
DatasetPtr
&
dataset
,
const
Config
&
config
)
override
;
void
void
Seal
()
override
;
Seal
()
override
;
...
@@ -66,11 +71,12 @@ class CPUKDTRNG : public VectorIndex {
...
@@ -66,11 +71,12 @@ class CPUKDTRNG : public VectorIndex {
private:
private:
PreprocessorPtr
preprocessor_
;
PreprocessorPtr
preprocessor_
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_ptr_
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_ptr_
;
SPTAG
::
IndexAlgoType
index_type_
;
};
};
using
CPU
KDTRNGPtr
=
std
::
shared_ptr
<
CPUKDT
RNG
>
;
using
CPU
SPTAGRNGPtr
=
std
::
shared_ptr
<
CPUSPTAG
RNG
>
;
class
CPU
KDT
RNGIndexModel
:
public
IndexModel
{
class
CPU
SPTAG
RNGIndexModel
:
public
IndexModel
{
public:
public:
BinarySet
BinarySet
Serialize
()
override
;
Serialize
()
override
;
...
@@ -82,6 +88,6 @@ class CPUKDTRNGIndexModel : public IndexModel {
...
@@ -82,6 +88,6 @@ class CPUKDTRNGIndexModel : public IndexModel {
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_
;
std
::
shared_ptr
<
SPTAG
::
VectorIndex
>
index_
;
};
};
using
CPU
KDTRNGIndexModelPtr
=
std
::
shared_ptr
<
CPUKDT
RNGIndexModel
>
;
using
CPU
SPTAGRNGIndexModelPtr
=
std
::
shared_ptr
<
CPUSPTAG
RNGIndexModel
>
;
}
// namespace knowhere
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h
浏览文件 @
e01eb29f
...
@@ -42,6 +42,32 @@ constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
...
@@ -42,6 +42,32 @@ constexpr int64_t DEFAULT_OUT_DEGREE = INVALID_VALUE;
constexpr
int64_t
DEFAULT_CANDIDATE_SISE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_CANDIDATE_SISE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NNG_K
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NNG_K
=
INVALID_VALUE
;
// SPTAG Config
constexpr
int64_t
DEFAULT_SAMPLES
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_TPTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_TPTLEAFSIZE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMTOPDIMENSIONTPTSPLIT
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NEIGHBORHOODSIZE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_GRAPHNEIGHBORHOODSCALE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_GRAPHCEFSCALE
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_REFINEITERATIONS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_CEF
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_MAXCHECKFORREFINEGRAPH
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMOFTHREADS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_MAXCHECK
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS
=
INVALID_VALUE
;
// KDT Config
constexpr
int64_t
DEFAULT_KDTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_NUMTOPDIMENSIONKDTSPLIT
=
INVALID_VALUE
;
// BKT Config
constexpr
int64_t
DEFAULT_BKTNUMBER
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_BKTKMEANSK
=
INVALID_VALUE
;
constexpr
int64_t
DEFAULT_BKTLEAFSIZE
=
INVALID_VALUE
;
struct
IVFCfg
:
public
Cfg
{
struct
IVFCfg
:
public
Cfg
{
int64_t
nlist
=
DEFAULT_NLIST
;
int64_t
nlist
=
DEFAULT_NLIST
;
int64_t
nprobe
=
DEFAULT_NPROBE
;
int64_t
nprobe
=
DEFAULT_NPROBE
;
...
@@ -135,8 +161,57 @@ struct NSGCfg : public IVFCfg {
...
@@ -135,8 +161,57 @@ struct NSGCfg : public IVFCfg {
};
};
using
NSGConfig
=
std
::
shared_ptr
<
NSGCfg
>
;
using
NSGConfig
=
std
::
shared_ptr
<
NSGCfg
>
;
struct
KDTCfg
:
public
Cfg
{
struct
SPTAGCfg
:
public
Cfg
{
int64_t
tptnubmber
=
-
1
;
int64_t
samples
=
DEFAULT_SAMPLES
;
int64_t
tptnumber
=
DEFAULT_TPTNUMBER
;
int64_t
tptleafsize
=
DEFAULT_TPTLEAFSIZE
;
int64_t
numtopdimensiontptsplit
=
DEFAULT_NUMTOPDIMENSIONTPTSPLIT
;
int64_t
neighborhoodsize
=
DEFAULT_NEIGHBORHOODSIZE
;
int64_t
graphneighborhoodscale
=
DEFAULT_GRAPHNEIGHBORHOODSCALE
;
int64_t
graphcefscale
=
DEFAULT_GRAPHCEFSCALE
;
int64_t
refineiterations
=
DEFAULT_REFINEITERATIONS
;
int64_t
cef
=
DEFAULT_CEF
;
int64_t
maxcheckforrefinegraph
=
DEFAULT_MAXCHECKFORREFINEGRAPH
;
int64_t
numofthreads
=
DEFAULT_NUMOFTHREADS
;
int64_t
maxcheck
=
DEFAULT_MAXCHECK
;
int64_t
thresholdofnumberofcontinuousnobetterpropagation
=
DEFAULT_THRESHOLDOFNUMBEROFCONTINUOUSNOBETTERPROPAGATION
;
int64_t
numberofinitialdynamicpivots
=
DEFAULT_NUMBEROFINITIALDYNAMICPIVOTS
;
int64_t
numberofotherdynamicpivots
=
DEFAULT_NUMBEROFOTHERDYNAMICPIVOTS
;
SPTAGCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
using
SPTAGConfig
=
std
::
shared_ptr
<
SPTAGCfg
>
;
struct
KDTCfg
:
public
SPTAGCfg
{
int64_t
kdtnumber
=
DEFAULT_KDTNUMBER
;
int64_t
numtopdimensionkdtsplit
=
DEFAULT_NUMTOPDIMENSIONKDTSPLIT
;
KDTCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
using
KDTConfig
=
std
::
shared_ptr
<
KDTCfg
>
;
struct
BKTCfg
:
public
SPTAGCfg
{
int64_t
bktnumber
=
DEFAULT_BKTNUMBER
;
int64_t
bktkmeansk
=
DEFAULT_BKTKMEANSK
;
int64_t
bktleafsize
=
DEFAULT_BKTLEAFSIZE
;
BKTCfg
()
=
default
;
bool
CheckValid
()
override
{
return
true
;
};
};
};
using
BKTConfig
=
std
::
shared_ptr
<
BKTCfg
>
;
}
// namespace knowhere
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/
KDT
ParameterMgr.cpp
→
core/src/index/knowhere/knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.cpp
浏览文件 @
e01eb29f
...
@@ -17,39 +17,59 @@
...
@@ -17,39 +17,59 @@
#include <mutex>
#include <mutex>
#include "knowhere/index/vector_index/helpers/
KDT
ParameterMgr.h"
#include "knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.h"
namespace
knowhere
{
namespace
knowhere
{
const
std
::
vector
<
KDTParameter
>
&
const
KDTConfig
&
KDT
ParameterMgr
::
GetKDTParameters
()
{
SPTAG
ParameterMgr
::
GetKDTParameters
()
{
return
kdt_
parameters
_
;
return
kdt_
config
_
;
}
}
KDTParameterMgr
::
KDTParameterMgr
()
{
const
BKTConfig
&
kdt_parameters_
=
std
::
vector
<
KDTParameter
>
{
SPTAGParameterMgr
::
GetBKTParameters
()
{
{
"KDTNumber"
,
"1"
},
return
bkt_config_
;
{
"NumTopDimensionKDTSplit"
,
"5"
},
}
{
"NumSamplesKDTSplitConsideration"
,
"100"
},
SPTAGParameterMgr
::
SPTAGParameterMgr
()
{
{
"TPTNumber"
,
"1"
},
kdt_config_
=
std
::
make_shared
<
KDTCfg
>
();
{
"TPTLeafSize"
,
"2000"
},
kdt_config_
->
kdtnumber
=
1
;
{
"NumTopDimensionTPTSplit"
,
"5"
},
kdt_config_
->
numtopdimensionkdtsplit
=
5
;
kdt_config_
->
samples
=
100
;
{
"NeighborhoodSize"
,
"32"
},
kdt_config_
->
tptnumber
=
1
;
{
"GraphNeighborhoodScale"
,
"2"
},
kdt_config_
->
tptleafsize
=
2000
;
{
"GraphCEFScale"
,
"2"
},
kdt_config_
->
numtopdimensiontptsplit
=
5
;
{
"RefineIterations"
,
"0"
},
kdt_config_
->
neighborhoodsize
=
32
;
{
"CEF"
,
"1000"
},
kdt_config_
->
graphneighborhoodscale
=
2
;
{
"MaxCheckForRefineGraph"
,
"10000"
},
kdt_config_
->
graphcefscale
=
2
;
kdt_config_
->
refineiterations
=
0
;
{
"NumberOfThreads"
,
"1"
},
kdt_config_
->
cef
=
1000
;
kdt_config_
->
maxcheckforrefinegraph
=
10000
;
{
"MaxCheck"
,
"8192"
},
kdt_config_
->
numofthreads
=
1
;
{
"ThresholdOfNumberOfContinuousNoBetterPropagation"
,
"3"
},
kdt_config_
->
maxcheck
=
8192
;
{
"NumberOfInitialDynamicPivots"
,
"50"
},
kdt_config_
->
thresholdofnumberofcontinuousnobetterpropagation
=
3
;
{
"NumberOfOtherDynamicPivots"
,
"4"
},
kdt_config_
->
numberofinitialdynamicpivots
=
50
;
};
kdt_config_
->
numberofotherdynamicpivots
=
4
;
bkt_config_
=
std
::
make_shared
<
BKTCfg
>
();
bkt_config_
->
bktnumber
=
1
;
bkt_config_
->
bktkmeansk
=
32
;
bkt_config_
->
bktleafsize
=
8
;
bkt_config_
->
samples
=
100
;
bkt_config_
->
tptnumber
=
1
;
bkt_config_
->
tptleafsize
=
2000
;
bkt_config_
->
numtopdimensiontptsplit
=
5
;
bkt_config_
->
neighborhoodsize
=
32
;
bkt_config_
->
graphneighborhoodscale
=
2
;
bkt_config_
->
graphcefscale
=
2
;
bkt_config_
->
refineiterations
=
0
;
bkt_config_
->
cef
=
1000
;
bkt_config_
->
maxcheckforrefinegraph
=
10000
;
bkt_config_
->
numofthreads
=
1
;
bkt_config_
->
maxcheck
=
8192
;
bkt_config_
->
thresholdofnumberofcontinuousnobetterpropagation
=
3
;
bkt_config_
->
numberofinitialdynamicpivots
=
50
;
bkt_config_
->
numberofotherdynamicpivots
=
4
;
}
}
}
// namespace knowhere
}
// namespace knowhere
core/src/index/knowhere/knowhere/index/vector_index/helpers/
KDT
ParameterMgr.h
→
core/src/index/knowhere/knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.h
浏览文件 @
e01eb29f
...
@@ -22,31 +22,40 @@
...
@@ -22,31 +22,40 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <SPTAG/AnnService/inc/Core/Common.h>
#include "IndexParameter.h"
namespace
knowhere
{
namespace
knowhere
{
using
KDTParameter
=
std
::
pair
<
std
::
string
,
std
::
string
>
;
using
KDTConfig
=
std
::
shared_ptr
<
KDTCfg
>
;
using
BKTConfig
=
std
::
shared_ptr
<
BKTCfg
>
;
class
KDT
ParameterMgr
{
class
SPTAG
ParameterMgr
{
public:
public:
const
std
::
vector
<
KDTParameter
>
&
const
KDTConfig
&
GetKDTParameters
();
GetKDTParameters
();
const
BKTConfig
&
GetBKTParameters
();
public:
public:
static
KDT
ParameterMgr
&
static
SPTAG
ParameterMgr
&
GetInstance
()
{
GetInstance
()
{
static
KDT
ParameterMgr
instance
;
static
SPTAG
ParameterMgr
instance
;
return
instance
;
return
instance
;
}
}
KDTParameterMgr
(
const
KDTParameterMgr
&
)
=
delete
;
SPTAGParameterMgr
(
const
SPTAGParameterMgr
&
)
=
delete
;
KDTParameterMgr
&
operator
=
(
const
KDTParameterMgr
&
)
=
delete
;
SPTAGParameterMgr
&
operator
=
(
const
SPTAGParameterMgr
&
)
=
delete
;
private:
private:
KDT
ParameterMgr
();
SPTAG
ParameterMgr
();
private:
private:
std
::
vector
<
KDTParameter
>
kdt_parameters_
;
KDTConfig
kdt_config_
;
BKTConfig
bkt_config_
;
};
};
}
// namespace knowhere
}
// namespace knowhere
core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h
浏览文件 @
e01eb29f
...
@@ -195,7 +195,7 @@ namespace SPTAG
...
@@ -195,7 +195,7 @@ namespace SPTAG
C
=
*
((
DimensionType
*
)
pDataPointsMemFile
);
C
=
*
((
DimensionType
*
)
pDataPointsMemFile
);
pDataPointsMemFile
+=
sizeof
(
DimensionType
);
pDataPointsMemFile
+=
sizeof
(
DimensionType
);
Initialize
(
R
,
C
,
(
T
*
)
pDataPointsMemFile
);
Initialize
(
R
,
C
,
(
T
*
)
pDataPointsMemFile
,
false
);
std
::
cout
<<
"Load "
<<
name
<<
" ("
<<
R
<<
", "
<<
C
<<
") Finish!"
<<
std
::
endl
;
std
::
cout
<<
"Load "
<<
name
<<
" ("
<<
R
<<
", "
<<
C
<<
") Finish!"
<<
std
::
endl
;
return
true
;
return
true
;
}
}
...
...
core/src/index/unittest/CMakeLists.txt
浏览文件 @
e01eb29f
...
@@ -82,17 +82,17 @@ if (NOT TARGET test_idmap)
...
@@ -82,17 +82,17 @@ if (NOT TARGET test_idmap)
endif
()
endif
()
target_link_libraries
(
test_idmap
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
target_link_libraries
(
test_idmap
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
#<
KDT
-TEST>
#<
SPTAG
-TEST>
set
(
kdt
_srcs
set
(
sptag
_srcs
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/adapter/SptagAdapter.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/adapter/SptagAdapter.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/preprocessor/Normalize.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/preprocessor/Normalize.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/helpers/
KDT
ParameterMgr.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/helpers/
SPTAG
ParameterMgr.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/Index
KDT
.cpp
${
INDEX_SOURCE_DIR
}
/knowhere/knowhere/index/vector_index/Index
SPTAG
.cpp
)
)
if
(
NOT TARGET test_
kdt
)
if
(
NOT TARGET test_
sptag
)
add_executable
(
test_
kdt test_kdt.cpp
${
kdt
_srcs
}
${
util_srcs
}
)
add_executable
(
test_
sptag test_sptag.cpp
${
sptag
_srcs
}
${
util_srcs
}
)
endif
()
endif
()
target_link_libraries
(
test_
kdt
target_link_libraries
(
test_
sptag
SPTAGLibStatic
SPTAGLibStatic
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
${
depend_libs
}
${
unittest_libs
}
${
basic_libs
}
)
...
@@ -106,7 +106,7 @@ endif ()
...
@@ -106,7 +106,7 @@ endif ()
install
(
TARGETS test_ivf DESTINATION unittest
)
install
(
TARGETS test_ivf DESTINATION unittest
)
install
(
TARGETS test_idmap DESTINATION unittest
)
install
(
TARGETS test_idmap DESTINATION unittest
)
install
(
TARGETS test_
kdt
DESTINATION unittest
)
install
(
TARGETS test_
sptag
DESTINATION unittest
)
if
(
KNOWHERE_GPU_VERSION
)
if
(
KNOWHERE_GPU_VERSION
)
install
(
TARGETS test_gpuresource DESTINATION unittest
)
install
(
TARGETS test_gpuresource DESTINATION unittest
)
install
(
TARGETS test_customized_index DESTINATION unittest
)
install
(
TARGETS test_customized_index DESTINATION unittest
)
...
...
core/src/index/unittest/test_
kdt
.cpp
→
core/src/index/unittest/test_
sptag
.cpp
浏览文件 @
e01eb29f
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/Index
KDT
.h"
#include "knowhere/index/vector_index/Index
SPTAG
.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "unittest/utils.h"
#include "unittest/utils.h"
...
@@ -32,28 +32,38 @@ using ::testing::Combine;
...
@@ -32,28 +32,38 @@ using ::testing::Combine;
using
::
testing
::
TestWithParam
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
Values
;
using
::
testing
::
Values
;
class
KDTTest
:
public
DataGen
,
public
::
testing
::
Test
{
class
SPTAGTest
:
public
DataGen
,
public
TestWithParam
<
std
::
string
>
{
protected:
protected:
void
void
SetUp
()
override
{
SetUp
()
override
{
Generate
(
96
,
1000
,
10
);
IndexType
=
GetParam
();
index_
=
std
::
make_shared
<
knowhere
::
CPUKDTRNG
>
();
Generate
(
128
,
100
,
5
);
index_
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
if
(
IndexType
==
"KDT"
)
{
auto
tempconf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
auto
tempconf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
tempconf
->
tptnubmber
=
1
;
tempconf
->
tptnumber
=
1
;
tempconf
->
k
=
10
;
conf
=
tempconf
;
}
else
{
auto
tempconf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
tempconf
->
tptnumber
=
1
;
tempconf
->
k
=
10
;
tempconf
->
k
=
10
;
conf
=
tempconf
;
conf
=
tempconf
;
}
Init_with_default
();
Init_with_default
();
}
}
protected:
protected:
knowhere
::
Config
conf
;
knowhere
::
Config
conf
;
std
::
shared_ptr
<
knowhere
::
CPUKDTRNG
>
index_
=
nullptr
;
std
::
shared_ptr
<
knowhere
::
CPUSPTAGRNG
>
index_
=
nullptr
;
std
::
string
IndexType
;
};
};
INSTANTIATE_TEST_CASE_P
(
SPTAGParameters
,
SPTAGTest
,
Values
(
"KDT"
,
"BKT"
));
// TODO(lxj): add test about count() and dimension()
// TODO(lxj): add test about count() and dimension()
TEST_
F
(
KDTTest
,
kdt
_basic
)
{
TEST_
P
(
SPTAGTest
,
sptag
_basic
)
{
assert
(
!
xb
.
empty
());
assert
(
!
xb
.
empty
());
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
...
@@ -75,10 +85,10 @@ TEST_F(KDTTest, kdt_basic) {
...
@@ -75,10 +85,10 @@ TEST_F(KDTTest, kdt_basic) {
std
::
stringstream
ss_dist
;
std
::
stringstream
ss_dist
;
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
// ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
}
}
ss_id
<<
std
::
endl
;
ss_id
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
...
@@ -88,57 +98,57 @@ TEST_F(KDTTest, kdt_basic) {
...
@@ -88,57 +98,57 @@ TEST_F(KDTTest, kdt_basic) {
}
}
}
}
// TODO(zirui): enable test
TEST_P
(
SPTAGTest
,
sptag_serialize
)
{
// TEST_F(KDTTest, kdt_serialize) {
assert
(
!
xb
.
empty
());
// assert(!xb.empty());
//
auto
preprocessor
=
index_
->
BuildPreprocessor
(
base_dataset
,
conf
);
// auto preprocessor = index_->BuildPreprocessor(base_dataset, conf
);
index_
->
set_preprocessor
(
preprocessor
);
// index_->set_preprocessor(preprocessor);
//
auto
model
=
index_
->
Train
(
base_dataset
,
conf
);
// auto model = index_->Train(base_dataset, conf);
// //
index_->Add(base_dataset, conf);
index_
->
Add
(
base_dataset
,
conf
);
//
auto binaryset = index_->Serialize();
auto
binaryset
=
index_
->
Serialize
();
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>(
);
auto
new_index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
//
new_index->Load(binaryset);
new_index
->
Load
(
binaryset
);
//
auto result = new_index->Search(query_dataset, conf);
auto
result
=
new_index
->
Search
(
query_dataset
,
conf
);
//
AssertAnns(result, nq, k);
AssertAnns
(
result
,
nq
,
k
);
//
PrintResult(result, nq, k);
PrintResult
(
result
,
nq
,
k
);
//
ASSERT_EQ(new_index->Count(), nb);
ASSERT_EQ
(
new_index
->
Count
(),
nb
);
//
ASSERT_EQ(new_index->Dimension(), dim);
ASSERT_EQ
(
new_index
->
Dimension
(),
dim
);
//
ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
//
ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException);
//
ASSERT_NO_THROW({ new_index->Seal(); });
//
ASSERT_NO_THROW({ new_index->Seal(); });
//
//
{
{
//
int fileno = 0;
int
fileno
=
0
;
// const std::string& base_name = "/tmp/kdt
_serialize_test_bin_";
const
std
::
string
&
base_name
=
"/tmp/sptag
_serialize_test_bin_"
;
//
std::vector<std::string> filename_list;
std
::
vector
<
std
::
string
>
filename_list
;
//
std::vector<std::pair<std::string, size_t>> meta_list;
std
::
vector
<
std
::
pair
<
std
::
string
,
size_t
>>
meta_list
;
//
for (auto& iter : binaryset.binary_map_) {
for
(
auto
&
iter
:
binaryset
.
binary_map_
)
{
//
const std::string& filename = base_name + std::to_string(fileno);
const
std
::
string
&
filename
=
base_name
+
std
::
to_string
(
fileno
);
//
FileIOWriter writer(filename);
FileIOWriter
writer
(
filename
);
//
writer(iter.second->data.get(), iter.second->size);
writer
(
iter
.
second
->
data
.
get
(),
iter
.
second
->
size
);
//
//
meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
meta_list
.
emplace_back
(
std
::
make_pair
(
iter
.
first
,
iter
.
second
->
size
));
//
filename_list.push_back(filename);
filename_list
.
push_back
(
filename
);
//
++fileno;
++
fileno
;
//
}
}
//
//
knowhere::BinarySet load_data_list;
knowhere
::
BinarySet
load_data_list
;
//
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
for
(
int
i
=
0
;
i
<
filename_list
.
size
()
&&
i
<
meta_list
.
size
();
++
i
)
{
//
auto bin_size = meta_list[i].second;
auto
bin_size
=
meta_list
[
i
].
second
;
//
FileIOReader reader(filename_list[i]);
FileIOReader
reader
(
filename_list
[
i
]);
//
//
auto load_data = new uint8_t[bin_size];
auto
load_data
=
new
uint8_t
[
bin_size
];
//
reader(load_data, bin_size);
reader
(
load_data
,
bin_size
);
//
auto data = std::make_shared<uint8_t>();
auto
data
=
std
::
make_shared
<
uint8_t
>
();
//
data.reset(load_data);
data
.
reset
(
load_data
);
//
load_data_list.Append(meta_list[i].first, data, bin_size);
load_data_list
.
Append
(
meta_list
[
i
].
first
,
data
,
bin_size
);
//
}
}
//
// auto new_index = std::make_shared<knowhere::CPUKDTRNG>(
);
auto
new_index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
IndexType
);
//
new_index->Load(load_data_list);
new_index
->
Load
(
load_data_list
);
//
auto result = new_index->Search(query_dataset, conf);
auto
result
=
new_index
->
Search
(
query_dataset
,
conf
);
//
AssertAnns(result, nq, k);
AssertAnns
(
result
,
nq
,
k
);
//
PrintResult(result, nq, k);
PrintResult
(
result
,
nq
,
k
);
//
}
}
//
}
}
core/src/index/unittest/utils.cpp
浏览文件 @
e01eb29f
...
@@ -160,15 +160,17 @@ AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
...
@@ -160,15 +160,17 @@ AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) {
void
void
PrintResult
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
PrintResult
(
const
knowhere
::
DatasetPtr
&
result
,
const
int
&
nq
,
const
int
&
k
)
{
auto
ids
=
result
->
array
()[
0
]
;
auto
ids
=
result
->
ids
()
;
auto
dists
=
result
->
array
()[
1
]
;
auto
dists
=
result
->
dist
()
;
std
::
stringstream
ss_id
;
std
::
stringstream
ss_id
;
std
::
stringstream
ss_dist
;
std
::
stringstream
ss_dist
;
for
(
auto
i
=
0
;
i
<
10
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
nq
;
i
++
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
for
(
auto
j
=
0
;
j
<
k
;
++
j
)
{
ss_id
<<
*
(
ids
->
data
()
->
GetValues
<
int64_t
>
(
1
,
i
*
k
+
j
))
<<
" "
;
// ss_id << *(ids->data()->GetValues<int64_t>(1, i * k + j)) << " ";
ss_dist
<<
*
(
dists
->
data
()
->
GetValues
<
float
>
(
1
,
i
*
k
+
j
))
<<
" "
;
// ss_dist << *(dists->data()->GetValues<float>(1, i * k + j)) << " ";
ss_id
<<
*
((
int64_t
*
)(
ids
)
+
i
*
k
+
j
)
<<
" "
;
ss_dist
<<
*
((
float
*
)(
dists
)
+
i
*
k
+
j
)
<<
" "
;
}
}
ss_id
<<
std
::
endl
;
ss_id
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
ss_dist
<<
std
::
endl
;
...
...
core/src/wrapper/ConfAdapter.cpp
浏览文件 @
e01eb29f
...
@@ -201,5 +201,35 @@ NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
...
@@ -201,5 +201,35 @@ NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type)
return
conf
;
return
conf
;
}
}
knowhere
::
Config
SPTAGKDTConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
conf
->
d
=
metaconf
.
dim
;
conf
->
metric_type
=
metaconf
.
metric_type
;
return
conf
;
}
knowhere
::
Config
SPTAGKDTConfAdapter
::
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
KDTCfg
>
();
conf
->
k
=
metaconf
.
k
;
return
conf
;
}
knowhere
::
Config
SPTAGBKTConfAdapter
::
Match
(
const
TempMetaConf
&
metaconf
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
conf
->
d
=
metaconf
.
dim
;
conf
->
metric_type
=
metaconf
.
metric_type
;
return
conf
;
}
knowhere
::
Config
SPTAGBKTConfAdapter
::
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
{
auto
conf
=
std
::
make_shared
<
knowhere
::
BKTCfg
>
();
conf
->
k
=
metaconf
.
k
;
return
conf
;
}
}
// namespace engine
}
// namespace engine
}
// namespace milvus
}
// namespace milvus
core/src/wrapper/ConfAdapter.h
浏览文件 @
e01eb29f
...
@@ -94,5 +94,23 @@ class NSGConfAdapter : public IVFConfAdapter {
...
@@ -94,5 +94,23 @@ class NSGConfAdapter : public IVFConfAdapter {
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
final
;
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
final
;
};
};
class
SPTAGKDTConfAdapter
:
public
ConfAdapter
{
public:
knowhere
::
Config
Match
(
const
TempMetaConf
&
metaconf
)
override
;
knowhere
::
Config
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
override
;
};
class
SPTAGBKTConfAdapter
:
public
ConfAdapter
{
public:
knowhere
::
Config
Match
(
const
TempMetaConf
&
metaconf
)
override
;
knowhere
::
Config
MatchSearch
(
const
TempMetaConf
&
metaconf
,
const
IndexType
&
type
)
override
;
};
}
// namespace engine
}
// namespace engine
}
// namespace milvus
}
// namespace milvus
core/src/wrapper/ConfAdapterMgr.cpp
浏览文件 @
e01eb29f
...
@@ -56,6 +56,9 @@ AdapterMgr::RegisterAdapter() {
...
@@ -56,6 +56,9 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER
(
IVFPQConfAdapter
,
IndexType
::
FAISS_IVFPQ_MIX
,
ivfpq_mix
);
REGISTER_CONF_ADAPTER
(
IVFPQConfAdapter
,
IndexType
::
FAISS_IVFPQ_MIX
,
ivfpq_mix
);
REGISTER_CONF_ADAPTER
(
NSGConfAdapter
,
IndexType
::
NSG_MIX
,
nsg_mix
);
REGISTER_CONF_ADAPTER
(
NSGConfAdapter
,
IndexType
::
NSG_MIX
,
nsg_mix
);
REGISTER_CONF_ADAPTER
(
SPTAGKDTConfAdapter
,
IndexType
::
SPTAG_KDT_RNT_CPU
,
sptag_kdt
);
REGISTER_CONF_ADAPTER
(
SPTAGBKTConfAdapter
,
IndexType
::
SPTAG_BKT_RNT_CPU
,
sptag_bkt
);
}
}
}
// namespace engine
}
// namespace engine
...
...
core/src/wrapper/VecImpl.cpp
浏览文件 @
e01eb29f
...
@@ -26,8 +26,8 @@
...
@@ -26,8 +26,8 @@
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#endif
#endif
...
...
core/src/wrapper/VecIndex.cpp
浏览文件 @
e01eb29f
...
@@ -22,8 +22,8 @@
...
@@ -22,8 +22,8 @@
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexKDT.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/IndexNSG.h"
#include "knowhere/index/vector_index/IndexSPTAG.h"
#include "utils/Log.h"
#include "utils/Log.h"
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
...
@@ -128,7 +128,11 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
...
@@ -128,7 +128,11 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
break
;
break
;
}
}
case
IndexType
::
SPTAG_KDT_RNT_CPU
:
{
case
IndexType
::
SPTAG_KDT_RNT_CPU
:
{
index
=
std
::
make_shared
<
knowhere
::
CPUKDTRNG
>
();
index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
"KDT"
);
break
;
}
case
IndexType
::
SPTAG_BKT_RNT_CPU
:
{
index
=
std
::
make_shared
<
knowhere
::
CPUSPTAGRNG
>
(
"BKT"
);
break
;
break
;
}
}
case
IndexType
::
FAISS_IVFSQ8_CPU
:
{
case
IndexType
::
FAISS_IVFSQ8_CPU
:
{
...
...
core/src/wrapper/VecIndex.h
浏览文件 @
e01eb29f
...
@@ -49,6 +49,7 @@ enum class IndexType {
...
@@ -49,6 +49,7 @@ enum class IndexType {
FAISS_IVFSQ8_HYBRID
,
// only support build on gpu.
FAISS_IVFSQ8_HYBRID
,
// only support build on gpu.
NSG_MIX
,
NSG_MIX
,
FAISS_IVFPQ_MIX
,
FAISS_IVFPQ_MIX
,
SPTAG_BKT_RNT_CPU
,
};
};
class
VecIndex
;
class
VecIndex
;
...
@@ -139,6 +140,9 @@ write_index(VecIndexPtr index, const std::string& location);
...
@@ -139,6 +140,9 @@ write_index(VecIndexPtr index, const std::string& location);
extern
VecIndexPtr
extern
VecIndexPtr
read_index
(
const
std
::
string
&
location
);
read_index
(
const
std
::
string
&
location
);
VecIndexPtr
read_index
(
const
std
::
string
&
location
,
knowhere
::
BinarySet
&
index_binary
);
extern
VecIndexPtr
extern
VecIndexPtr
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
GetVecIndexFactory
(
const
IndexType
&
type
,
const
Config
&
cfg
=
Config
());
...
...
core/unittest/wrapper/test_wrapper.cpp
浏览文件 @
e01eb29f
...
@@ -16,28 +16,29 @@
...
@@ -16,28 +16,29 @@
// under the License.
// under the License.
#include "easyloggingpp/easylogging++.h"
#include "easyloggingpp/easylogging++.h"
#include "wrapper/VecIndex.h"
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
#endif
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "wrapper/VecIndex.h"
#include "wrapper/utils.h"
#include "wrapper/utils.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
INITIALIZE_EASYLOGGINGPP
INITIALIZE_EASYLOGGINGPP
using
::
testing
::
Combine
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
TestWithParam
;
using
::
testing
::
Values
;
using
::
testing
::
Values
;
using
::
testing
::
Combine
;
class
KnowhereWrapperTest
class
KnowhereWrapperTest
:
public
DataGenBase
,
:
public
DataGenBase
,
public
TestWithParam
<::
std
::
tuple
<
milvus
::
engine
::
IndexType
,
std
::
string
,
int
,
int
,
int
,
int
>>
{
public
TestWithParam
<::
std
::
tuple
<
milvus
::
engine
::
IndexType
,
std
::
string
,
int
,
int
,
int
,
int
>>
{
protected:
protected:
void
SetUp
()
override
{
void
SetUp
()
override
{
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
InitDevice
(
DEVICEID
,
PINMEM
,
TEMPMEM
,
RESNUM
);
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
InitDevice
(
DEVICEID
,
PINMEM
,
TEMPMEM
,
RESNUM
);
#endif
#endif
...
@@ -58,7 +59,8 @@ class KnowhereWrapperTest
...
@@ -58,7 +59,8 @@ class KnowhereWrapperTest
searchconf
=
ParamGenerator
::
GetInstance
().
GenSearchConf
(
index_type
,
tempconf
);
searchconf
=
ParamGenerator
::
GetInstance
().
GenSearchConf
(
index_type
,
tempconf
);
}
}
void
TearDown
()
override
{
void
TearDown
()
override
{
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
knowhere
::
FaissGpuResourceMgr
::
GetInstance
().
Free
();
#endif
#endif
...
@@ -71,22 +73,20 @@ class KnowhereWrapperTest
...
@@ -71,22 +73,20 @@ class KnowhereWrapperTest
knowhere
::
Config
searchconf
;
knowhere
::
Config
searchconf
;
};
};
INSTANTIATE_TEST_CASE_P
(
WrapperParam
,
KnowhereWrapperTest
,
INSTANTIATE_TEST_CASE_P
(
WrapperParam
,
KnowhereWrapperTest
,
Values
(
Values
(
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
#ifdef MILVUS_GPU_VERSION
#ifdef MILVUS_GPU_VERSION
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
// std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB,
// 10, 10),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_GPU
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_MIX
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_MIX
,
"Default"
,
DIM
,
NB
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFPQ_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFPQ_MIX
,
"Default"
,
64
,
1000
,
10
,
10
),
// std::make_tuple(milvus::engine::IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
// std::make_tuple(IndexType::NSG_MIX, "Default", 128, 250000, 10, 10),
#endif
#endif
// std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 250000, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 100, 10, 10),
// std::make_tuple(milvus::engine::IndexType::SPTAG_BKT_RNT_CPU, "Default", 128, 100, 10, 10),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IDMAP
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IDMAP
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_CPU
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFFLAT_CPU
,
"Default"
,
64
,
1000
,
10
,
10
),
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_CPU
,
"Default"
,
DIM
,
NB
,
10
,
10
)));
std
::
make_tuple
(
milvus
::
engine
::
IndexType
::
FAISS_IVFSQ8_CPU
,
"Default"
,
DIM
,
NB
,
10
,
10
)));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录