提交 8be48142 编写于 作者: X xiaojun.lin

Merge remote-tracking branch 'upstream/branch-0.5.0' into branch-0.5.1


Former-commit-id: eed45d9c3d43d8058e8223bc64e8ef2b687135b2
......@@ -18,3 +18,10 @@
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
IndentWidth: 4
AccessModifierOffset: -3
AlwaysBreakAfterReturnType: All
AllowShortBlocksOnASingleLine: false
AllowShortFunctionsOnASingleLine: false
AllowShortIfStatementsOnASingleLine: false
AlignTrailingComments: true
......@@ -18,8 +18,10 @@
Checks: 'clang-diagnostic-*,clang-analyzer-*,-clang-analyzer-alpha*,google-*,modernize-*,readability-*'
# produce HeaderFilterRegex from cpp/build-support/lint_exclusions.txt with:
# echo -n '^('; sed -e 's/*/\.*/g' cpp/build-support/lint_exclusions.txt | tr '\n' '|'; echo ')$'
HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/thirdparty.*|.*src/core/thirdparty.*|.*src/grpc.*|)$'
HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/core/thirdparty.*|.*thirdparty.*|.*easylogging++.*|.*SqliteMetaImpl.cpp|.*src/grpc.*|.*src/core.*|.*src/wrapper.*)$'
AnalyzeTemporaryDtors: true
ChainedConditionalReturn: 1
ChainedConditionalAssignment: 1
CheckOptions:
- key: google-readability-braces-around-statements.ShortStatementLines
value: '1'
......
......@@ -9,6 +9,7 @@ container('milvus-build-env') {
sh "git config --global user.email \"test@zilliz.com\""
sh "git config --global user.name \"test\""
withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
sh "./build.sh -l"
sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' && export JFROG_USER_NAME='${USERNAME}' && export JFROG_PASSWORD='${PASSWORD}' && ./build.sh -t ${params.BUILD_TYPE} -j -u -c"
}
}
......
......@@ -9,6 +9,7 @@ container('milvus-build-env') {
sh "git config --global user.email \"test@zilliz.com\""
sh "git config --global user.name \"test\""
withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
sh "./build.sh -l"
sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' && export JFROG_USER_NAME='${USERNAME}' && export JFROG_PASSWORD='${PASSWORD}' && ./build.sh -t ${params.BUILD_TYPE} -j"
}
}
......
......@@ -33,12 +33,30 @@ pipeline {
cloud 'build-kubernetes'
label 'build'
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
yaml """
apiVersion: v1
kind: Pod
metadata:
name: milvus-build-env
labels:
app: milvus
componet: build-env
spec:
containers:
- name: milvus-build-env
image: registry.zilliz.com/milvus/milvus-build-env:v0.13
command:
- cat
tty: true
resources:
limits:
memory: "28Gi"
cpu: "10.0"
nvidia.com/gpu: 1
requests:
memory: "14Gi"
cpu: "5.0"
"""
}
}
stages {
......
......@@ -33,12 +33,30 @@ pipeline {
cloud 'build-kubernetes'
label 'build'
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
yaml """
apiVersion: v1
kind: Pod
metadata:
name: milvus-build-env
labels:
app: milvus
componet: build-env
spec:
containers:
- name: milvus-build-env
image: registry.zilliz.com/milvus/milvus-build-env:v0.13
command:
- cat
tty: true
resources:
limits:
memory: "28Gi"
cpu: "10.0"
nvidia.com/gpu: 1
requests:
memory: "14Gi"
cpu: "5.0"
"""
}
}
stages {
......
......@@ -33,12 +33,30 @@ pipeline {
cloud 'build-kubernetes'
label 'build'
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
yaml """
apiVersion: v1
kind: Pod
metadata:
name: milvus-build-env
labels:
app: milvus
componet: build-env
spec:
containers:
- name: milvus-build-env
image: registry.zilliz.com/milvus/milvus-build-env:v0.13
command:
- cat
tty: true
resources:
limits:
memory: "28Gi"
cpu: "10.0"
nvidia.com/gpu: 1
requests:
memory: "14Gi"
cpu: "5.0"
"""
}
}
stages {
......
......@@ -9,21 +9,32 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-572 - Milvus crash when get SIGINT
- MS-577 - Unittest Query randomly hung
- MS-587 - Count get wrong result after adding vectors and index built immediately
- MS-599 - search wrong result when table created with metric_type: IP
- MS-601 - Docker logs error caused by get CPUTemperature error
- MS-622 - Delete vectors should be failed if date range is invalid
- MS-620 - Get table row counts display wrong error code
## Improvement
- MS-552 - Add and change the easylogging library
- MS-553 - Refine cache code
- MS-557 - Merge Log.h
- MS-555 - Remove old scheduler
- MS-556 - Add Job Definition in Scheduler
- MS-557 - Merge Log.h
- MS-558 - Refine status code
- MS-562 - Add JobMgr and TaskCreator in Scheduler
- MS-566 - Refactor cmake
- MS-555 - Remove old scheduler
- MS-574 - Milvus configuration refactor
- MS-578 - Make sure milvus5.0 don't crack 0.3.1 data
- MS-585 - Update namespace in scheduler
- MS-606 - Speed up result reduce
- MS-608 - Update TODO names
- MS-609 - Update task construct function
- MS-611 - Add resources validity check in ResourceMgr
- MS-619 - Add optimizer class in scheduler
- MS-614 - Preload table at startup
## New Feature
- MS-627 - Integrate new index: IVFSQHybrid
## Task
- MS-554 - Change license to Apache 2.0
......@@ -33,6 +44,9 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-575 - Add Clang-format & Clang-tidy & Cpplint
- MS-586 - Remove BUILD_FAISS_WITH_MKL option
- MS-590 - Refine cmake code to support cpplint
- MS-600 - Reconstruct unittest code
- MS-602 - Remove zilliz namespace
- MS-610 - Change error code base value from hex to decimal
# Milvus 0.4.0 (2019-09-12)
......
......@@ -189,7 +189,7 @@ add_custom_target(lint
--exclude_globs
${LINT_EXCLUSIONS_FILE}
--source_dir
${CMAKE_CURRENT_SOURCE_DIR}/src
${CMAKE_CURRENT_SOURCE_DIR}
${MILVUS_LINT_QUIET})
#
......
......@@ -105,15 +105,31 @@ please reinstall CMake with curl:
```
##### code format and linting
Install clang-format and clang-tidy
```shell
CentOS 7:
$ yum install clang
Ubuntu 16.04 or 18.04:
$ sudo apt-get install clang-format clang-tidy
Ubuntu 16.04:
$ sudo apt-get install clang-tidy
$ sudo su
$ wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
$ apt-add-repository "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main"
$ apt-get update
$ apt-get install clang-format-6.0
Ubuntu 18.04:
$ sudo apt-get install clang-tidy clang-format
$ rm cmake_build/CMakeCache.txt
```
Check code style
```shell
$ ./build.sh -l
```
To format the code
```shell
$ cd cmake_build
$ make clang-format
```
##### Run unit test
......@@ -122,13 +138,14 @@ $ ./build.sh -u
```
##### Run code coverage
Install lcov
```shell
CentOS 7:
$ yum install lcov
Ubuntu 16.04 or 18.04:
$ sudo apt-get install lcov
```
```shell
$ ./build.sh -u -c
```
......
<code_scheme name="milvus" version="173">
<Objective-C>
<option name="INDENT_NAMESPACE_MEMBERS" value="0" />
<option name="INDENT_VISIBILITY_KEYWORDS" value="1" />
<option name="KEEP_STRUCTURES_IN_ONE_LINE" value="true" />
<option name="KEEP_CASE_EXPRESSIONS_IN_ONE_LINE" value="true" />
<option name="FUNCTION_NON_TOP_AFTER_RETURN_TYPE_WRAP" value="0" />
<option name="FUNCTION_TOP_AFTER_RETURN_TYPE_WRAP" value="2" />
<option name="FUNCTION_PARAMETERS_WRAP" value="5" />
<option name="FUNCTION_CALL_ARGUMENTS_WRAP" value="5" />
<option name="TEMPLATE_CALL_ARGUMENTS_WRAP" value="5" />
<option name="TEMPLATE_CALL_ARGUMENTS_ALIGN_MULTILINE" value="true" />
<option name="CLASS_CONSTRUCTOR_INIT_LIST_WRAP" value="5" />
<option name="ALIGN_INIT_LIST_IN_COLUMNS" value="false" />
<option name="SPACE_BEFORE_PROTOCOLS_BRACKETS" value="false" />
<option name="SPACE_BEFORE_POINTER_IN_DECLARATION" value="false" />
<option name="SPACE_AFTER_POINTER_IN_DECLARATION" value="true" />
<option name="SPACE_BEFORE_REFERENCE_IN_DECLARATION" value="false" />
<option name="SPACE_AFTER_REFERENCE_IN_DECLARATION" value="true" />
<option name="KEEP_BLANK_LINES_BEFORE_END" value="1" />
</Objective-C>
<codeStyleSettings language="ObjectiveC">
<option name="KEEP_BLANK_LINES_IN_DECLARATIONS" value="1" />
<option name="KEEP_BLANK_LINES_IN_CODE" value="1" />
<option name="KEEP_BLANK_LINES_BEFORE_RBRACE" value="1" />
<option name="BLANK_LINES_AROUND_CLASS" value="0" />
<option name="BLANK_LINES_AROUND_METHOD_IN_INTERFACE" value="0" />
<option name="BLANK_LINES_AFTER_CLASS_HEADER" value="1" />
<option name="SPACE_AFTER_TYPE_CAST" value="false" />
<option name="BINARY_OPERATION_SIGN_ON_NEXT_LINE" value="true" />
<option name="KEEP_SIMPLE_BLOCKS_IN_ONE_LINE" value="false" />
<option name="FOR_STATEMENT_WRAP" value="1" />
<option name="ASSIGNMENT_WRAP" value="1" />
<indentOptions>
<option name="CONTINUATION_INDENT_SIZE" value="4" />
</indentOptions>
</codeStyleSettings>
</code_scheme>
\ No newline at end of file
*cmake-build-debug*
*cmake-build-release*
*cmake_build*
*src/thirdparty*
*src/core/thirdparty*
*src/grpc*
*thirdparty*
*easylogging++*
*SqliteMetaImpl.cpp
\ No newline at end of file
*SqliteMetaImpl.cpp
*src/grpc*
*milvus/include*
\ No newline at end of file
......@@ -99,21 +99,26 @@ if [[ ${RUN_CPPLINT} == "ON" ]]; then
# cpplint check
make lint
if [ $? -ne 0 ]; then
echo "ERROR! cpplint check not pass"
echo "ERROR! cpplint check failed"
exit 1
fi
echo "cpplint check passed!"
# clang-format check
make check-clang-format
if [ $? -ne 0 ]; then
echo "ERROR! clang-format check failed"
exit 1
fi
# clang-tidy check
make check-clang-tidy
if [ $? -ne 0 ]; then
echo "ERROR! clang-tidy check failed"
exit 1
fi
echo "clang-format check passed!"
# # clang-tidy check
# make check-clang-tidy
# if [ $? -ne 0 ]; then
# echo "ERROR! clang-tidy check failed"
# exit 1
# fi
# echo "clang-tidy check passed!"
else
# compile and build
make -j 4 || exit 1
......
......@@ -11,25 +11,30 @@ db_config:
secondary_path: # path used to store data only, split by semicolon
backend_url: sqlite://:@:/ # URI format: dialect://username:password@host:port/database
# Keep 'dialect://:@:/', and replace other texts with real values.
# Keep 'dialect://:@:/', and replace other texts with real values
# Replace 'dialect' with 'mysql' or 'sqlite'
insert_buffer_size: 4 # GB, maximum insert buffer size allowed
# sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory
build_index_gpu: 0 # gpu id used for building index
preload_table: # preload data at startup, '*' means load all tables, empty value means no preload
# you can specify preload tables like this: table1,table2,table3
metric_config:
enable_monitor: false # enable monitoring or not
collector: prometheus # prometheus
prometheus_config:
port: 8080 # port prometheus used to fetch metrics
port: 8080 # port prometheus uses to fetch metrics
cache_config:
cpu_mem_capacity: 16 # GB, CPU memory used for cache
cpu_mem_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered
cpu_cache_capacity: 16 # GB, CPU memory used for cache
cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered
cache_insert_data: false # whether to load inserted data into cache
engine_config:
blas_threshold: 20
use_blas_threshold: 20 # if nq < use_blas_threshold, use SSE, faster with fluctuated response times
# if nq >= use_blas_threshold, use OpenBlas, slower with stable response times
resource_config:
resource_pool:
......
......@@ -70,11 +70,11 @@ fi
for test in `ls ${DIR_UNITTEST}`; do
echo $test
case ${test} in
db_test)
# set run args for db_test
test_db)
# set run args for test_db
args="mysql://${MYSQL_USER_NAME}:${MYSQL_PASSWORD}@${MYSQL_HOST}:${MYSQL_PORT}/${MYSQL_DB_NAME}"
;;
*_test)
*test_*)
args=""
;;
esac
......@@ -104,7 +104,7 @@ ${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \
"src/metrics/MetricBase.h"\
"src/server/Server.cpp"\
"src/server/DBWrapper.cpp"\
"src/server/grpc_impl/GrpcMilvusServer.cpp"\
"src/server/grpc_impl/GrpcServer.cpp"\
"src/utils/easylogging++.h"\
"src/utils/easylogging++.cc"\
......
......@@ -20,29 +20,24 @@
include_directories(${MILVUS_SOURCE_DIR})
include_directories(${MILVUS_ENGINE_SRC})
add_subdirectory(core)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status)
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus)
#this statement must put here, since the CORE_INCLUDE_DIRS is defined in code/CMakeList.txt
add_subdirectory(core)
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
foreach (dir ${CORE_INCLUDE_DIRS})
include_directories(${dir})
endforeach ()
aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler db_scheduler_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/context db_scheduler_context_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/task db_scheduler_task_files)
set(db_scheduler_files
${db_scheduler_files}
${db_scheduler_context_files}
${db_scheduler_task_files}
)
set(grpc_service_files
${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc
......@@ -51,12 +46,11 @@ set(grpc_service_files
${MILVUS_ENGINE_SRC}/grpc/gen-status/status.pb.cc
)
aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler scheduler_main_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/action scheduler_action_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/event scheduler_event_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/job scheduler_job_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/optimizer scheduler_optimizer_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/resource scheduler_resource_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/task scheduler_task_files)
set(scheduler_files
......@@ -64,15 +58,14 @@ set(scheduler_files
${scheduler_action_files}
${scheduler_event_files}
${scheduler_job_files}
${scheduler_optimizer_files}
${scheduler_resource_files}
${scheduler_task_files}
)
aux_source_directory(${MILVUS_ENGINE_SRC}/server server_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/server/grpc_impl grpc_server_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/utils utils_files)
aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper wrapper_files)
set(engine_files
......@@ -82,16 +75,11 @@ set(engine_files
${db_engine_files}
${db_insert_files}
${db_meta_files}
${db_scheduler_files}
${metrics_files}
${utils_files}
${wrapper_files}
)
include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status)
include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus)
set(client_grpc_lib
grpcpp_channelz
grpc++
......@@ -112,6 +100,12 @@ set(boost_lib
boost_serialization_static
)
set(cuda_lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
cudart
cublas
)
set(third_party_libs
sqlite
${client_grpc_lib}
......@@ -123,17 +117,15 @@ set(third_party_libs
snappy
zlib
zstd
cudart
cublas
${cuda_lib}
mysqlpp
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
cudart
)
if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
set(third_party_libs ${third_party_libs}
gperftools
libunwind
)
gperftools
libunwind
)
endif ()
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
......@@ -141,7 +133,6 @@ set(engine_libs
pthread
libgomp.a
libgfortran.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
)
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
......@@ -152,7 +143,11 @@ if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
endif ()
cuda_add_library(milvus_engine STATIC ${engine_files})
target_link_libraries(milvus_engine ${engine_libs} knowhere ${third_party_libs})
target_link_libraries(milvus_engine
knowhere
${engine_libs}
${third_party_libs}
)
add_library(metrics STATIC ${metrics_files})
......@@ -180,7 +175,9 @@ add_executable(milvus_server
${utils_files}
)
target_link_libraries(milvus_server ${server_libs})
target_link_libraries(milvus_server
${server_libs}
)
install(TARGETS milvus_server DESTINATION bin)
......
......@@ -15,55 +15,66 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "LRU.h"
#include "utils/Log.h"
#include <string>
#include <mutex>
#include <atomic>
#include <mutex>
#include <set>
#include <string>
namespace zilliz {
namespace milvus {
namespace cache {
template<typename ItemObj>
template <typename ItemObj>
class Cache {
public:
//mem_capacity, units:GB
// mem_capacity, units:GB
Cache(int64_t capacity_gb, uint64_t cache_max_count);
~Cache() = default;
int64_t usage() const {
int64_t
usage() const {
return usage_;
}
int64_t capacity() const {
int64_t
capacity() const {
return capacity_;
} //unit: BYTE
void set_capacity(int64_t capacity); //unit: BYTE
} // unit: BYTE
void
set_capacity(int64_t capacity); // unit: BYTE
double freemem_percent() const {
double
freemem_percent() const {
return freemem_percent_;
}
void set_freemem_percent(double percent) {
void
set_freemem_percent(double percent) {
freemem_percent_ = percent;
}
size_t size() const;
bool exists(const std::string &key);
ItemObj get(const std::string &key);
void insert(const std::string &key, const ItemObj &item);
void erase(const std::string &key);
void print();
void clear();
size_t
size() const;
bool
exists(const std::string& key);
ItemObj
get(const std::string& key);
void
insert(const std::string& key, const ItemObj& item);
void
erase(const std::string& key);
void
print();
void
clear();
private:
void free_memory();
void
free_memory();
private:
int64_t usage_;
......@@ -74,8 +85,7 @@ class Cache {
mutable std::mutex mutex_;
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
#include "cache/Cache.inl"
......@@ -17,7 +17,7 @@
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -190,5 +190,5 @@ Cache<ItemObj>::print() {
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -15,40 +15,48 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "Cache.h"
#include "utils/Log.h"
#include "metrics/Metrics.h"
#include "utils/Log.h"
#include <string>
#include <memory>
#include <string>
namespace zilliz {
namespace milvus {
namespace cache {
template<typename ItemObj>
template <typename ItemObj>
class CacheMgr {
public:
virtual uint64_t ItemCount() const;
virtual uint64_t
ItemCount() const;
virtual bool ItemExists(const std::string &key);
virtual bool
ItemExists(const std::string& key);
virtual ItemObj GetItem(const std::string &key);
virtual ItemObj
GetItem(const std::string& key);
virtual void InsertItem(const std::string &key, const ItemObj &data);
virtual void
InsertItem(const std::string& key, const ItemObj& data);
virtual void EraseItem(const std::string &key);
virtual void
EraseItem(const std::string& key);
virtual void PrintInfo();
virtual void
PrintInfo();
virtual void ClearCache();
virtual void
ClearCache();
int64_t CacheUsage() const;
int64_t CacheCapacity() const;
void SetCapacity(int64_t capacity);
int64_t
CacheUsage() const;
int64_t
CacheCapacity() const;
void
SetCapacity(int64_t capacity);
protected:
CacheMgr();
......@@ -59,8 +67,7 @@ class CacheMgr {
CachePtr cache_;
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
#include "cache/CacheMgr.inl"
......@@ -16,7 +16,7 @@
// under the License.
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -142,4 +142,4 @@ CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -15,14 +15,12 @@
// specific language governing permissions and limitations
// under the License.
#include "cache/CpuCacheMgr.h"
#include "server/Config.h"
#include "utils/Log.h"
#include <utility>
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -31,38 +29,38 @@ constexpr int64_t unit = 1024 * 1024 * 1024;
}
CpuCacheMgr::CpuCacheMgr() {
server::Config &config = server::Config::GetInstance();
server::Config& config = server::Config::GetInstance();
Status s;
int32_t cpu_mem_cap;
s = config.GetCacheConfigCpuMemCapacity(cpu_mem_cap);
int32_t cpu_cache_cap;
s = config.GetCacheConfigCpuCacheCapacity(cpu_cache_cap);
if (!s.ok()) {
SERVER_LOG_ERROR << s.message();
}
int64_t cap = cpu_mem_cap * unit;
int64_t cap = cpu_cache_cap * unit;
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
float cpu_mem_threshold;
s = config.GetCacheConfigCpuMemThreshold(cpu_mem_threshold);
float cpu_cache_threshold;
s = config.GetCacheConfigCpuCacheThreshold(cpu_cache_threshold);
if (!s.ok()) {
SERVER_LOG_ERROR << s.message();
}
if (cpu_mem_threshold > 0.0 && cpu_mem_threshold <= 1.0) {
cache_->set_freemem_percent(cpu_mem_threshold);
if (cpu_cache_threshold > 0.0 && cpu_cache_threshold <= 1.0) {
cache_->set_freemem_percent(cpu_cache_threshold);
} else {
SERVER_LOG_ERROR << "Invalid cpu_mem_threshold: " << cpu_mem_threshold
<< ", by default set to " << cache_->freemem_percent();
SERVER_LOG_ERROR << "Invalid cpu_cache_threshold: " << cpu_cache_threshold << ", by default set to "
<< cache_->freemem_percent();
}
}
CpuCacheMgr *
CpuCacheMgr*
CpuCacheMgr::GetInstance() {
static CpuCacheMgr s_mgr;
return &s_mgr;
}
engine::VecIndexPtr
CpuCacheMgr::GetIndex(const std::string &key) {
CpuCacheMgr::GetIndex(const std::string& key) {
DataObjPtr obj = GetItem(key);
if (obj != nullptr) {
return obj->data();
......@@ -71,6 +69,5 @@ CpuCacheMgr::GetIndex(const std::string &key) {
return nullptr;
}
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -20,10 +20,9 @@
#include "CacheMgr.h"
#include "DataObj.h"
#include <string>
#include <memory>
#include <string>
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -32,12 +31,13 @@ class CpuCacheMgr : public CacheMgr<DataObjPtr> {
CpuCacheMgr();
public:
//TODO: use smart pointer instead
static CpuCacheMgr *GetInstance();
// TODO(myh): use smart pointer instead
static CpuCacheMgr*
GetInstance();
engine::VecIndexPtr GetIndex(const std::string &key);
engine::VecIndexPtr
GetIndex(const std::string& key);
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -15,37 +15,35 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "src/wrapper/VecIndex.h"
#include <memory>
namespace zilliz {
namespace milvus {
namespace cache {
class DataObj {
public:
explicit DataObj(const engine::VecIndexPtr &index)
: index_(index) {
explicit DataObj(const engine::VecIndexPtr& index) : index_(index) {
}
DataObj(const engine::VecIndexPtr &index, int64_t size)
: index_(index),
size_(size) {
DataObj(const engine::VecIndexPtr& index, int64_t size) : index_(index), size_(size) {
}
engine::VecIndexPtr data() {
engine::VecIndexPtr
data() {
return index_;
}
const engine::VecIndexPtr &data() const {
const engine::VecIndexPtr&
data() const {
return index_;
}
int64_t size() const {
int64_t
size() const {
if (index_ == nullptr) {
return 0;
}
......@@ -64,6 +62,5 @@ class DataObj {
using DataObjPtr = std::shared_ptr<DataObj>;
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -15,15 +15,13 @@
// specific language governing permissions and limitations
// under the License.
#include "cache/GpuCacheMgr.h"
#include "utils/Log.h"
#include "server/Config.h"
#include "utils/Log.h"
#include <sstream>
#include <utility>
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -35,31 +33,31 @@ constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
}
GpuCacheMgr::GpuCacheMgr() {
server::Config &config = server::Config::GetInstance();
server::Config& config = server::Config::GetInstance();
Status s;
int32_t gpu_mem_cap;
s = config.GetCacheConfigGpuMemCapacity(gpu_mem_cap);
int32_t gpu_cache_cap;
s = config.GetCacheConfigGpuCacheCapacity(gpu_cache_cap);
if (!s.ok()) {
SERVER_LOG_ERROR << s.message();
}
int32_t cap = gpu_mem_cap * G_BYTE;
int32_t cap = gpu_cache_cap * G_BYTE;
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
float gpu_mem_threshold;
s = config.GetCacheConfigGpuMemThreshold(gpu_mem_threshold);
s = config.GetCacheConfigGpuCacheThreshold(gpu_mem_threshold);
if (!s.ok()) {
SERVER_LOG_ERROR << s.message();
}
if (gpu_mem_threshold > 0.0 && gpu_mem_threshold <= 1.0) {
cache_->set_freemem_percent(gpu_mem_threshold);
} else {
SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold
<< ", by default set to " << cache_->freemem_percent();
SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold << ", by default set to "
<< cache_->freemem_percent();
}
}
GpuCacheMgr *
GpuCacheMgr*
GpuCacheMgr::GetInstance(uint64_t gpu_id) {
if (instance_.find(gpu_id) == instance_.end()) {
std::lock_guard<std::mutex> lock(mutex_);
......@@ -74,7 +72,7 @@ GpuCacheMgr::GetInstance(uint64_t gpu_id) {
}
engine::VecIndexPtr
GpuCacheMgr::GetIndex(const std::string &key) {
GpuCacheMgr::GetIndex(const std::string& key) {
DataObjPtr obj = GetItem(key);
if (obj != nullptr) {
return obj->data();
......@@ -83,6 +81,5 @@ GpuCacheMgr::GetIndex(const std::string &key) {
return nullptr;
}
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -15,15 +15,13 @@
// specific language governing permissions and limitations
// under the License.
#include "CacheMgr.h"
#include "DataObj.h"
#include <unordered_map>
#include <memory>
#include <string>
#include <unordered_map>
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -34,15 +32,16 @@ class GpuCacheMgr : public CacheMgr<DataObjPtr> {
public:
GpuCacheMgr();
static GpuCacheMgr *GetInstance(uint64_t gpu_id);
static GpuCacheMgr*
GetInstance(uint64_t gpu_id);
engine::VecIndexPtr GetIndex(const std::string &key);
engine::VecIndexPtr
GetIndex(const std::string& key);
private:
static std::mutex mutex_;
static std::unordered_map<uint64_t, GpuCacheMgrPtr> instance_;
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -15,20 +15,18 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <unordered_map>
#include <list>
#include <cstddef>
#include <list>
#include <stdexcept>
#include <unordered_map>
#include <utility>
namespace zilliz {
namespace milvus {
namespace cache {
template<typename key_t, typename value_t>
template <typename key_t, typename value_t>
class LRU {
public:
typedef typename std::pair<key_t, value_t> key_value_pair_t;
......@@ -38,7 +36,8 @@ class LRU {
explicit LRU(size_t max_size) : max_size_(max_size) {
}
void put(const key_t &key, const value_t &value) {
void
put(const key_t& key, const value_t& value) {
auto it = cache_items_map_.find(key);
cache_items_list_.push_front(key_value_pair_t(key, value));
if (it != cache_items_map_.end()) {
......@@ -55,7 +54,8 @@ class LRU {
}
}
const value_t &get(const key_t &key) {
const value_t&
get(const key_t& key) {
auto it = cache_items_map_.find(key);
if (it == cache_items_map_.end()) {
throw std::range_error("There is no such key in cache");
......@@ -65,7 +65,8 @@ class LRU {
}
}
void erase(const key_t &key) {
void
erase(const key_t& key) {
auto it = cache_items_map_.find(key);
if (it != cache_items_map_.end()) {
cache_items_list_.erase(it->second);
......@@ -73,32 +74,39 @@ class LRU {
}
}
bool exists(const key_t &key) const {
bool
exists(const key_t& key) const {
return cache_items_map_.find(key) != cache_items_map_.end();
}
size_t size() const {
size_t
size() const {
return cache_items_map_.size();
}
list_iterator_t begin() {
list_iterator_t
begin() {
iter_ = cache_items_list_.begin();
return iter_;
}
list_iterator_t end() {
list_iterator_t
end() {
return cache_items_list_.end();
}
reverse_list_iterator_t rbegin() {
reverse_list_iterator_t
rbegin() {
return cache_items_list_.rbegin();
}
reverse_list_iterator_t rend() {
reverse_list_iterator_t
rend() {
return cache_items_list_.rend();
}
void clear() {
void
clear() {
cache_items_list_.clear();
cache_items_map_.clear();
}
......@@ -110,7 +118,5 @@ class LRU {
list_iterator_t iter_;
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
} // namespace cache
} // namespace milvus
......@@ -18,16 +18,14 @@
#include "config/ConfigMgr.h"
#include "YamlConfigMgr.h"
namespace zilliz {
namespace milvus {
namespace server {
ConfigMgr *
ConfigMgr*
ConfigMgr::GetInstance() {
static YamlConfigMgr mgr;
return &mgr;
}
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -17,12 +17,11 @@
#pragma once
#include "utils/Error.h"
#include "ConfigNode.h"
#include "utils/Error.h"
#include <string>
namespace zilliz {
namespace milvus {
namespace server {
......@@ -42,16 +41,21 @@ namespace server {
class ConfigMgr {
public:
static ConfigMgr *GetInstance();
virtual ErrorCode LoadConfigFile(const std::string &filename) = 0;
virtual void Print() const = 0;//will be deleted
virtual std::string DumpString() const = 0;
virtual const ConfigNode &GetRootNode() const = 0;
virtual ConfigNode &GetRootNode() = 0;
static ConfigMgr*
GetInstance();
virtual ErrorCode
LoadConfigFile(const std::string& filename) = 0;
virtual void
Print() const = 0; // will be deleted
virtual std::string
DumpString() const = 0;
virtual const ConfigNode&
GetRootNode() const = 0;
virtual ConfigNode&
GetRootNode() = 0;
};
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -19,51 +19,50 @@
#include "utils/Error.h"
#include "utils/Log.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <algorithm>
namespace zilliz {
namespace milvus {
namespace server {
void
ConfigNode::Combine(const ConfigNode &target) {
const std::map<std::string, std::string> &kv = target.GetConfig();
ConfigNode::Combine(const ConfigNode& target) {
const std::map<std::string, std::string>& kv = target.GetConfig();
for (auto itr = kv.begin(); itr != kv.end(); ++itr) {
config_[itr->first] = itr->second;
}
const std::map<std::string, std::vector<std::string> > &sequences = target.GetSequences();
const std::map<std::string, std::vector<std::string> >& sequences = target.GetSequences();
for (auto itr = sequences.begin(); itr != sequences.end(); ++itr) {
sequences_[itr->first] = itr->second;
}
const std::map<std::string, ConfigNode> &children = target.GetChildren();
const std::map<std::string, ConfigNode>& children = target.GetChildren();
for (auto itr = children.begin(); itr != children.end(); ++itr) {
children_[itr->first] = itr->second;
}
}
//key/value pair config
// key/value pair config
void
ConfigNode::SetValue(const std::string &key, const std::string &value) {
ConfigNode::SetValue(const std::string& key, const std::string& value) {
config_[key] = value;
}
std::string
ConfigNode::GetValue(const std::string &param_key, const std::string &default_val) const {
ConfigNode::GetValue(const std::string& param_key, const std::string& default_val) const {
auto ref = config_.find(param_key);
if (ref != config_.end()) {
return ref->second;
}
//THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key);
// THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key);
return default_val;
}
bool
ConfigNode::GetBoolValue(const std::string &param_key, bool default_val) const {
ConfigNode::GetBoolValue(const std::string& param_key, bool default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
std::transform(val.begin(), val.end(), val.begin(), ::tolower);
......@@ -74,17 +73,17 @@ ConfigNode::GetBoolValue(const std::string &param_key, bool default_val) const {
}
int32_t
ConfigNode::GetInt32Value(const std::string &param_key, int32_t default_val) const {
ConfigNode::GetInt32Value(const std::string& param_key, int32_t default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
return (int32_t) std::strtol(val.c_str(), nullptr, 10);
return (int32_t)std::strtol(val.c_str(), nullptr, 10);
} else {
return default_val;
}
}
int64_t
ConfigNode::GetInt64Value(const std::string &param_key, int64_t default_val) const {
ConfigNode::GetInt64Value(const std::string& param_key, int64_t default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
return std::strtol(val.c_str(), nullptr, 10);
......@@ -94,7 +93,7 @@ ConfigNode::GetInt64Value(const std::string &param_key, int64_t default_val) con
}
float
ConfigNode::GetFloatValue(const std::string &param_key, float default_val) const {
ConfigNode::GetFloatValue(const std::string& param_key, float default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
return std::strtof(val.c_str(), nullptr);
......@@ -104,7 +103,7 @@ ConfigNode::GetFloatValue(const std::string &param_key, float default_val) const
}
double
ConfigNode::GetDoubleValue(const std::string &param_key, double default_val) const {
ConfigNode::GetDoubleValue(const std::string& param_key, double default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
return std::strtod(val.c_str(), nullptr);
......@@ -113,7 +112,7 @@ ConfigNode::GetDoubleValue(const std::string &param_key, double default_val) con
}
}
const std::map<std::string, std::string> &
const std::map<std::string, std::string>&
ConfigNode::GetConfig() const {
return config_;
}
......@@ -123,14 +122,14 @@ ConfigNode::ClearConfig() {
config_.clear();
}
//key/object config
// key/object config
void
ConfigNode::AddChild(const std::string &type_name, const ConfigNode &config) {
ConfigNode::AddChild(const std::string& type_name, const ConfigNode& config) {
children_[type_name] = config;
}
ConfigNode
ConfigNode::GetChild(const std::string &type_name) const {
ConfigNode::GetChild(const std::string& type_name) const {
auto ref = children_.find(type_name);
if (ref != children_.end()) {
return ref->second;
......@@ -140,20 +139,20 @@ ConfigNode::GetChild(const std::string &type_name) const {
return nc;
}
ConfigNode &
ConfigNode::GetChild(const std::string &type_name) {
ConfigNode&
ConfigNode::GetChild(const std::string& type_name) {
return children_[type_name];
}
void
ConfigNode::GetChildren(ConfigNodeArr &arr) const {
ConfigNode::GetChildren(ConfigNodeArr& arr) const {
arr.clear();
for (auto ref : children_) {
arr.push_back(ref.second);
}
}
const std::map<std::string, ConfigNode> &
const std::map<std::string, ConfigNode>&
ConfigNode::GetChildren() const {
return children_;
}
......@@ -163,14 +162,14 @@ ConfigNode::ClearChildren() {
children_.clear();
}
//key/sequence config
// key/sequence config
void
ConfigNode::AddSequenceItem(const std::string &key, const std::string &item) {
ConfigNode::AddSequenceItem(const std::string& key, const std::string& item) {
sequences_[key].push_back(item);
}
std::vector<std::string>
ConfigNode::GetSequence(const std::string &key) const {
ConfigNode::GetSequence(const std::string& key) const {
auto itr = sequences_.find(key);
if (itr != sequences_.end()) {
return itr->second;
......@@ -180,7 +179,7 @@ ConfigNode::GetSequence(const std::string &key) const {
}
}
const std::map<std::string, std::vector<std::string> > &
const std::map<std::string, std::vector<std::string> >&
ConfigNode::GetSequences() const {
return sequences_;
}
......@@ -191,40 +190,40 @@ ConfigNode::ClearSequences() {
}
void
ConfigNode::PrintAll(const std::string &prefix) const {
for (auto &elem : config_) {
ConfigNode::PrintAll(const std::string& prefix) const {
for (auto& elem : config_) {
SERVER_LOG_INFO << prefix << elem.first + ": " << elem.second;
}
for (auto &elem : sequences_) {
for (auto& elem : sequences_) {
SERVER_LOG_INFO << prefix << elem.first << ": ";
for (auto &str : elem.second) {
for (auto& str : elem.second) {
SERVER_LOG_INFO << prefix << " - " << str;
}
}
for (auto &elem : children_) {
for (auto& elem : children_) {
SERVER_LOG_INFO << prefix << elem.first << ": ";
elem.second.PrintAll(prefix + " ");
}
}
std::string
ConfigNode::DumpString(const std::string &prefix) const {
ConfigNode::DumpString(const std::string& prefix) const {
std::stringstream str_buffer;
const std::string endl = "\n";
for (auto &elem : config_) {
for (auto& elem : config_) {
str_buffer << prefix << elem.first << ": " << elem.second << endl;
}
for (auto &elem : sequences_) {
for (auto& elem : sequences_) {
str_buffer << prefix << elem.first << ": " << endl;
for (auto &str : elem.second) {
for (auto& str : elem.second) {
str_buffer << prefix + " - " << str << endl;
}
}
for (auto &elem : children_) {
for (auto& elem : children_) {
str_buffer << prefix << elem.first << ": " << endl;
str_buffer << elem.second.DumpString(prefix + " ") << endl;
}
......@@ -232,6 +231,5 @@ ConfigNode::DumpString(const std::string &prefix) const {
return str_buffer.str();
}
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -17,11 +17,10 @@
#pragma once
#include <vector>
#include <string>
#include <map>
#include <string>
#include <vector>
namespace zilliz {
namespace milvus {
namespace server {
......@@ -30,39 +29,61 @@ typedef std::vector<ConfigNode> ConfigNodeArr;
class ConfigNode {
public:
void Combine(const ConfigNode &target);
void
Combine(const ConfigNode& target);
//key/value pair config
void SetValue(const std::string &key, const std::string &value);
// key/value pair config
void
SetValue(const std::string& key, const std::string& value);
std::string GetValue(const std::string &param_key, const std::string &default_val = "") const;
bool GetBoolValue(const std::string &param_key, bool default_val = false) const;
int32_t GetInt32Value(const std::string &param_key, int32_t default_val = 0) const;
int64_t GetInt64Value(const std::string &param_key, int64_t default_val = 0) const;
float GetFloatValue(const std::string &param_key, float default_val = 0.0) const;
double GetDoubleValue(const std::string &param_key, double default_val = 0.0) const;
std::string
GetValue(const std::string& param_key, const std::string& default_val = "") const;
bool
GetBoolValue(const std::string& param_key, bool default_val = false) const;
int32_t
GetInt32Value(const std::string& param_key, int32_t default_val = 0) const;
int64_t
GetInt64Value(const std::string& param_key, int64_t default_val = 0) const;
float
GetFloatValue(const std::string& param_key, float default_val = 0.0) const;
double
GetDoubleValue(const std::string& param_key, double default_val = 0.0) const;
const std::map<std::string, std::string> &GetConfig() const;
void ClearConfig();
const std::map<std::string, std::string>&
GetConfig() const;
void
ClearConfig();
//key/object config
void AddChild(const std::string &type_name, const ConfigNode &config);
ConfigNode GetChild(const std::string &type_name) const;
ConfigNode &GetChild(const std::string &type_name);
void GetChildren(ConfigNodeArr &arr) const;
// key/object config
void
AddChild(const std::string& type_name, const ConfigNode& config);
ConfigNode
GetChild(const std::string& type_name) const;
ConfigNode&
GetChild(const std::string& type_name);
void
GetChildren(ConfigNodeArr& arr) const;
const std::map<std::string, ConfigNode> &GetChildren() const;
void ClearChildren();
const std::map<std::string, ConfigNode>&
GetChildren() const;
void
ClearChildren();
//key/sequence config
void AddSequenceItem(const std::string &key, const std::string &item);
std::vector<std::string> GetSequence(const std::string &key) const;
// key/sequence config
void
AddSequenceItem(const std::string& key, const std::string& item);
std::vector<std::string>
GetSequence(const std::string& key) const;
const std::map<std::string, std::vector<std::string> > &GetSequences() const;
void ClearSequences();
const std::map<std::string, std::vector<std::string> >&
GetSequences() const;
void
ClearSequences();
void PrintAll(const std::string &prefix = "") const;
std::string DumpString(const std::string &prefix = "") const;
void
PrintAll(const std::string& prefix = "") const;
std::string
DumpString(const std::string& prefix = "") const;
private:
std::map<std::string, std::string> config_;
......@@ -70,6 +91,5 @@ class ConfigNode {
std::map<std::string, std::vector<std::string> > sequences_;
};
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -20,12 +20,11 @@
#include <sys/stat.h>
namespace zilliz {
namespace milvus {
namespace server {
ErrorCode
YamlConfigMgr::LoadConfigFile(const std::string &filename) {
YamlConfigMgr::LoadConfigFile(const std::string& filename) {
struct stat directoryStat;
int statOK = stat(filename.c_str(), &directoryStat);
if (statOK != 0) {
......@@ -36,8 +35,7 @@ YamlConfigMgr::LoadConfigFile(const std::string &filename) {
try {
node_ = YAML::LoadFile(filename);
LoadConfigNode(node_, config_);
}
catch (YAML::Exception &e) {
} catch (YAML::Exception& e) {
SERVER_LOG_ERROR << "Failed to load config file: " << std::string(e.what());
return SERVER_UNEXPECTED_ERROR;
}
......@@ -56,20 +54,18 @@ YamlConfigMgr::DumpString() const {
return config_.DumpString("");
}
const ConfigNode &
const ConfigNode&
YamlConfigMgr::GetRootNode() const {
return config_;
}
ConfigNode &
ConfigNode&
YamlConfigMgr::GetRootNode() {
return config_;
}
bool
YamlConfigMgr::SetConfigValue(const YAML::Node &node,
const std::string &key,
ConfigNode &config) {
YamlConfigMgr::SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config) {
if (node[key].IsDefined()) {
config.SetValue(key, node[key].as<std::string>());
return true;
......@@ -78,9 +74,7 @@ YamlConfigMgr::SetConfigValue(const YAML::Node &node,
}
bool
YamlConfigMgr::SetChildConfig(const YAML::Node &node,
const std::string &child_name,
ConfigNode &config) {
YamlConfigMgr::SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config) {
if (node[child_name].IsDefined()) {
ConfigNode sub_config;
LoadConfigNode(node[child_name], sub_config);
......@@ -91,9 +85,7 @@ YamlConfigMgr::SetChildConfig(const YAML::Node &node,
}
bool
YamlConfigMgr::SetSequence(const YAML::Node &node,
const std::string &child_name,
ConfigNode &config) {
YamlConfigMgr::SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config) {
if (node[child_name].IsDefined()) {
size_t cnt = node[child_name].size();
for (size_t i = 0; i < cnt; i++) {
......@@ -105,7 +97,7 @@ YamlConfigMgr::SetSequence(const YAML::Node &node,
}
void
YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) {
YamlConfigMgr::LoadConfigNode(const YAML::Node& node, ConfigNode& config) {
std::string key;
for (YAML::const_iterator it = node.begin(); it != node.end(); ++it) {
if (!it->first.IsNull()) {
......@@ -121,6 +113,5 @@ YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) {
}
}
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -21,43 +21,43 @@
#include "ConfigNode.h"
#include "utils/Error.h"
#include <string>
#include <yaml-cpp/yaml.h>
#include <string>
namespace zilliz {
namespace milvus {
namespace server {
class YamlConfigMgr : public ConfigMgr {
public:
virtual ErrorCode LoadConfigFile(const std::string &filename);
virtual void Print() const;
virtual std::string DumpString() const;
virtual const ConfigNode &GetRootNode() const;
virtual ConfigNode &GetRootNode();
virtual ErrorCode
LoadConfigFile(const std::string& filename);
virtual void
Print() const;
virtual std::string
DumpString() const;
virtual const ConfigNode&
GetRootNode() const;
virtual ConfigNode&
GetRootNode();
private:
bool SetConfigValue(const YAML::Node &node,
const std::string &key,
ConfigNode &config);
bool
SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config);
bool SetChildConfig(const YAML::Node &node,
const std::string &name,
ConfigNode &config);
bool
SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config);
bool
SetSequence(const YAML::Node &node,
const std::string &child_name,
ConfigNode &config);
SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config);
void LoadConfigNode(const YAML::Node &node, ConfigNode &config);
void
LoadConfigNode(const YAML::Node& node, ConfigNode& config);
private:
YAML::Node node_;
ConfigNode config_;
};
} // namespace server
} // namespace milvus
} // namespace zilliz
} // namespace server
} // namespace milvus
......@@ -46,9 +46,11 @@ if(NOT CMAKE_BUILD_TYPE)
endif(NOT CMAKE_BUILD_TYPE)
if(CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g")
endif()
MESSAGE(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
......@@ -93,7 +95,7 @@ endif()
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
if(BUILD_UNIT_TEST STREQUAL "ON")
add_subdirectory(test)
add_subdirectory(unittest)
endif()
config_summary()
文件模式从 100755 更改为 100644
......@@ -15,42 +15,39 @@
// specific language governing permissions and limitations
// under the License.
#include "knowhere/adapter/ArrowAdapter.h"
#include "ArrowAdapter.h"
namespace zilliz {
namespace knowhere {
ArrayPtr
CopyArray(const ArrayPtr &origin) {
CopyArray(const ArrayPtr& origin) {
ArrayPtr copy = nullptr;
auto copy_data = origin->data()->Copy();
switch (origin->type_id()) {
#define DEFINE_TYPE(type, clazz) \
case arrow::Type::type: { \
copy = std::make_shared<arrow::clazz>(copy_data); \
}
#define DEFINE_TYPE(type, clazz) \
case arrow::Type::type: { \
copy = std::make_shared<arrow::clazz>(copy_data); \
}
DEFINE_TYPE(BOOL, BooleanArray)
DEFINE_TYPE(BINARY, BinaryArray)
DEFINE_TYPE(FIXED_SIZE_BINARY, FixedSizeBinaryArray)
DEFINE_TYPE(DECIMAL, Decimal128Array)
DEFINE_TYPE(FLOAT, NumericArray<arrow::FloatType>)
DEFINE_TYPE(INT64, NumericArray<arrow::Int64Type>)
default:break;
default:
break;
}
return copy;
}
SchemaPtr
CopySchema(const SchemaPtr &origin) {
CopySchema(const SchemaPtr& origin) {
std::vector<std::shared_ptr<Field>> fields;
for (auto &field : origin->fields()) {
auto copy = std::make_shared<Field>(field->name(), field->type(),field->nullable(), nullptr);
for (auto& field : origin->fields()) {
auto copy = std::make_shared<Field>(field->name(), field->type(), field->nullable(), nullptr);
fields.emplace_back(copy);
}
return std::make_shared<Schema>(std::move(fields));
}
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,22 +15,20 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "knowhere/common/Array.h"
namespace zilliz {
namespace knowhere {
ArrayPtr
CopyArray(const ArrayPtr &origin);
CopyArray(const ArrayPtr& origin);
SchemaPtr
CopySchema(const SchemaPtr &origin);
CopySchema(const SchemaPtr& origin);
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,36 +15,30 @@
// specific language governing permissions and limitations
// under the License.
#include "knowhere/adapter/SptagAdapter.h"
#include "knowhere/adapter/Structure.h"
#include "knowhere/index/vector_index/helpers/Definitions.h"
#include "SptagAdapter.h"
#include "Structure.h"
namespace zilliz {
namespace knowhere {
std::shared_ptr<SPTAG::MetadataSet>
ConvertToMetadataSet(const DatasetPtr &dataset) {
ConvertToMetadataSet(const DatasetPtr& dataset) {
auto array = dataset->array()[0];
auto elems = array->length();
auto p_data = array->data()->GetValues<int64_t>(1, 0);
auto p_offset = (int64_t *) malloc(sizeof(int64_t) * elems);
for (auto i = 0; i <= elems; ++i)
p_offset[i] = i * 8;
std::shared_ptr<SPTAG::MetadataSet> metaset(new SPTAG::MemMetadataSet(
SPTAG::ByteArray((std::uint8_t *) p_data, elems * sizeof(int64_t), false),
SPTAG::ByteArray((std::uint8_t *) p_offset, elems * sizeof(int64_t), true),
elems));
auto p_offset = (int64_t*)malloc(sizeof(int64_t) * elems);
for (auto i = 0; i <= elems; ++i) p_offset[i] = i * 8;
std::shared_ptr<SPTAG::MetadataSet> metaset(
new SPTAG::MemMetadataSet(SPTAG::ByteArray((std::uint8_t*)p_data, elems * sizeof(int64_t), false),
SPTAG::ByteArray((std::uint8_t*)p_offset, elems * sizeof(int64_t), true), elems));
return metaset;
}
std::shared_ptr<SPTAG::VectorSet>
ConvertToVectorSet(const DatasetPtr &dataset) {
ConvertToVectorSet(const DatasetPtr& dataset) {
auto tensor = dataset->tensor()[0];
auto p_data = tensor->raw_mutable_data();
......@@ -54,18 +48,16 @@ ConvertToVectorSet(const DatasetPtr &dataset) {
SPTAG::ByteArray byte_array(p_data, num_bytes, false);
auto vectorset = std::make_shared<SPTAG::BasicVectorSet>(byte_array,
SPTAG::VectorValueType::Float,
dimension,
rows);
auto vectorset =
std::make_shared<SPTAG::BasicVectorSet>(byte_array, SPTAG::VectorValueType::Float, dimension, rows);
return vectorset;
}
std::vector<SPTAG::QueryResult>
ConvertToQueryResult(const DatasetPtr &dataset, const Config &config) {
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) {
auto tensor = dataset->tensor()[0];
auto p_data = (float *) tensor->raw_mutable_data();
auto p_data = (float*)tensor->raw_mutable_data();
auto dimension = tensor->shape()[1];
auto rows = tensor->shape()[0];
......@@ -82,8 +74,8 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
auto k = query_results[0].GetResultNum();
auto elems = query_results.size() * k;
auto p_id = (int64_t *) malloc(sizeof(int64_t) * elems);
auto p_dist = (float *) malloc(sizeof(float) * elems);
auto p_id = (int64_t*)malloc(sizeof(int64_t) * elems);
auto p_dist = (float*)malloc(sizeof(float) * elems);
// TODO: throw if malloc failed.
#pragma omp parallel for
......@@ -91,14 +83,14 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
auto results = query_results[i].GetResults();
auto num_result = query_results[i].GetResultNum();
for (auto j = 0; j < num_result; ++j) {
// p_id[i * k + j] = results[j].VID;
p_id[i * k + j] = *(int64_t *) query_results[i].GetMetadata(j).Data();
// p_id[i * k + j] = results[j].VID;
p_id[i * k + j] = *(int64_t*)query_results[i].GetMetadata(j).Data();
p_dist[i * k + j] = results[j].Dist;
}
}
auto id_buf = MakeMutableBufferSmart((uint8_t *) p_id, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t *) p_dist, sizeof(float) * elems);
auto id_buf = MakeMutableBufferSmart((uint8_t*)p_id, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t*)p_dist, sizeof(float) * elems);
// TODO: magic
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
......@@ -109,11 +101,11 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
auto id_array_data = arrow::ArrayData::Make(int64_type, elems, id_bufs);
auto dist_array_data = arrow::ArrayData::Make(float_type, elems, dist_bufs);
// auto id_array_data = std::make_shared<ArrayData>(int64_type, sizeof(int64_t) * elems, id_bufs);
// auto dist_array_data = std::make_shared<ArrayData>(float_type, sizeof(float) * elems, dist_bufs);
// auto id_array_data = std::make_shared<ArrayData>(int64_type, sizeof(int64_t) * elems, id_bufs);
// auto dist_array_data = std::make_shared<ArrayData>(float_type, sizeof(float) * elems, dist_bufs);
// auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems);
// auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems);
// auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems);
// auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems);
auto ids = std::make_shared<NumericArray<arrow::Int64Type>>(id_array_data);
auto dists = std::make_shared<NumericArray<arrow::FloatType>>(dist_array_data);
......@@ -127,5 +119,4 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results) {
return std::make_shared<Dataset>(array, schema);
}
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,29 +15,26 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <SPTAG/AnnService/inc/Core/VectorIndex.h>
#include <memory>
#include <vector>
#include "knowhere/common/Dataset.h"
namespace zilliz {
namespace knowhere {
std::shared_ptr<SPTAG::VectorSet>
ConvertToVectorSet(const DatasetPtr &dataset);
ConvertToVectorSet(const DatasetPtr& dataset);
std::shared_ptr<SPTAG::MetadataSet>
ConvertToMetadataSet(const DatasetPtr &dataset);
ConvertToMetadataSet(const DatasetPtr& dataset);
std::vector<SPTAG::QueryResult>
ConvertToQueryResult(const DatasetPtr &dataset, const Config &config);
ConvertToQueryResult(const DatasetPtr& dataset, const Config& config);
DatasetPtr
ConvertToDataset(std::vector<SPTAG::QueryResult> query_results);
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,15 +15,15 @@
// specific language governing permissions and limitations
// under the License.
#include "knowhere/adapter/Structure.h"
#include "Structure.h"
#include <string>
#include <vector>
namespace zilliz {
namespace knowhere {
ArrayPtr
ConstructInt64ArraySmart(uint8_t *data, int64_t size) {
ConstructInt64ArraySmart(uint8_t* data, int64_t size) {
// TODO: magic
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBufferSmart(data, size)};
auto type = std::make_shared<arrow::Int64Type>();
......@@ -32,7 +32,7 @@ ConstructInt64ArraySmart(uint8_t *data, int64_t size) {
}
ArrayPtr
ConstructFloatArraySmart(uint8_t *data, int64_t size) {
ConstructFloatArraySmart(uint8_t* data, int64_t size) {
// TODO: magic
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBufferSmart(data, size)};
auto type = std::make_shared<arrow::FloatType>();
......@@ -41,14 +41,14 @@ ConstructFloatArraySmart(uint8_t *data, int64_t size) {
}
TensorPtr
ConstructFloatTensorSmart(uint8_t *data, int64_t size, std::vector<int64_t> shape) {
ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector<int64_t> shape) {
auto buffer = MakeMutableBufferSmart(data, size);
auto float_type = std::make_shared<arrow::FloatType>();
return std::make_shared<Tensor>(float_type, buffer, shape);
}
ArrayPtr
ConstructInt64Array(uint8_t *data, int64_t size) {
ConstructInt64Array(uint8_t* data, int64_t size) {
// TODO: magic
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBuffer(data, size)};
auto type = std::make_shared<arrow::Int64Type>();
......@@ -57,7 +57,7 @@ ConstructInt64Array(uint8_t *data, int64_t size) {
}
ArrayPtr
ConstructFloatArray(uint8_t *data, int64_t size) {
ConstructFloatArray(uint8_t* data, int64_t size) {
// TODO: magic
std::vector<BufferPtr> id_buf{nullptr, MakeMutableBuffer(data, size)};
auto type = std::make_shared<arrow::FloatType>();
......@@ -66,23 +66,22 @@ ConstructFloatArray(uint8_t *data, int64_t size) {
}
TensorPtr
ConstructFloatTensor(uint8_t *data, int64_t size, std::vector<int64_t> shape) {
ConstructFloatTensor(uint8_t* data, int64_t size, std::vector<int64_t> shape) {
auto buffer = MakeMutableBuffer(data, size);
auto float_type = std::make_shared<arrow::FloatType>();
return std::make_shared<Tensor>(float_type, buffer, shape);
}
FieldPtr
ConstructInt64Field(const std::string &name) {
ConstructInt64Field(const std::string& name) {
auto type = std::make_shared<arrow::Int64Type>();
return std::make_shared<Field>(name, type);
}
FieldPtr
ConstructFloatField(const std::string &name) {
ConstructFloatField(const std::string& name) {
auto type = std::make_shared<arrow::FloatType>();
return std::make_shared<Field>(name, type);
}
}
}
} // namespace knowhere
......@@ -15,40 +15,38 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "knowhere/common/Dataset.h"
#include <string>
#include <vector>
#include "knowhere/common/Dataset.h"
namespace zilliz {
namespace knowhere {
extern ArrayPtr
ConstructInt64ArraySmart(uint8_t *data, int64_t size);
ConstructInt64ArraySmart(uint8_t* data, int64_t size);
extern ArrayPtr
ConstructFloatArraySmart(uint8_t *data, int64_t size);
ConstructFloatArraySmart(uint8_t* data, int64_t size);
extern TensorPtr
ConstructFloatTensorSmart(uint8_t *data, int64_t size, std::vector<int64_t> shape);
ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector<int64_t> shape);
extern ArrayPtr
ConstructInt64Array(uint8_t *data, int64_t size);
ConstructInt64Array(uint8_t* data, int64_t size);
extern ArrayPtr
ConstructFloatArray(uint8_t *data, int64_t size);
ConstructFloatArray(uint8_t* data, int64_t size);
extern TensorPtr
ConstructFloatTensor(uint8_t *data, int64_t size, std::vector<int64_t> shape);
ConstructFloatTensor(uint8_t* data, int64_t size, std::vector<int64_t> shape);
extern FieldPtr
ConstructInt64Field(const std::string &name);
ConstructInt64Field(const std::string& name);
extern FieldPtr
ConstructFloatField(const std::string &name);
ConstructFloatField(const std::string& name);
}
}
} // namespace knowhere
......@@ -15,18 +15,14 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
namespace zilliz {
namespace knowhere {
#define GETTENSOR(dataset) \
auto tensor = dataset->tensor()[0]; \
auto p_data = tensor->raw_data(); \
auto dim = tensor->shape()[1]; \
auto rows = tensor->shape()[0]; \
#define GETTENSOR(dataset) \
auto tensor = dataset->tensor()[0]; \
auto p_data = tensor->raw_data(); \
auto dim = tensor->shape()[1]; \
auto rows = tensor->shape()[0];
}
}
\ No newline at end of file
} // namespace knowhere
......@@ -15,15 +15,13 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <arrow/array.h>
#include <memory>
#include "Schema.h"
namespace zilliz {
namespace knowhere {
using ArrayData = arrow::ArrayData;
......@@ -35,9 +33,9 @@ using ArrayPtr = std::shared_ptr<Array>;
using BooleanArray = arrow::BooleanArray;
using BooleanArrayPtr = std::shared_ptr<arrow::BooleanArray>;
template<typename DType>
template <typename DType>
using NumericArray = arrow::NumericArray<DType>;
template<typename DType>
template <typename DType>
using NumericArrayPtr = std::shared_ptr<arrow::NumericArray<DType>>;
using BinaryArray = arrow::BinaryArray;
......@@ -49,6 +47,4 @@ using FixedSizeBinaryArrayPtr = std::shared_ptr<arrow::FixedSizeBinaryArray>;
using Decimal128Array = arrow::Decimal128Array;
using Decimal128ArrayPtr = std::shared_ptr<arrow::Decimal128Array>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,21 +15,18 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "Id.h"
namespace zilliz {
namespace knowhere {
struct Binary {
ID id;
std::shared_ptr<uint8_t> data;
......@@ -37,29 +34,28 @@ struct Binary {
};
using BinaryPtr = std::shared_ptr<Binary>;
class BinarySet {
public:
BinaryPtr
GetByName(const std::string &name) const {
GetByName(const std::string& name) const {
return binary_map_.at(name);
}
void
Append(const std::string &name, BinaryPtr binary) {
Append(const std::string& name, BinaryPtr binary) {
binary_map_[name] = std::move(binary);
}
void
Append(const std::string &name, std::shared_ptr<uint8_t> data, int64_t size) {
Append(const std::string& name, std::shared_ptr<uint8_t> data, int64_t size) {
auto binary = std::make_shared<Binary>();
binary->data = data;
binary->size = size;
binary_map_[name] = std::move(binary);
}
//void
//Append(const std::string &name, void *data, int64_t size, ID id) {
// void
// Append(const std::string &name, void *data, int64_t size, ID id) {
// Binary binary;
// binary.data = data;
// binary.size = size;
......@@ -67,7 +63,8 @@ class BinarySet {
// binary_map_[name] = binary;
//}
void clear() {
void
clear() {
binary_map_.clear();
}
......@@ -75,6 +72,4 @@ class BinarySet {
std::map<std::string, BinaryPtr> binary_map_;
};
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,15 +15,12 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <arrow/buffer.h>
namespace zilliz {
namespace knowhere {
using Buffer = arrow::Buffer;
......@@ -34,31 +31,31 @@ using MutableBufferPtr = std::shared_ptr<MutableBuffer>;
namespace internal {
struct BufferDeleter {
void operator()(Buffer *buffer) {
free((void *) buffer->data());
void
operator()(Buffer* buffer) {
free((void*)buffer->data());
}
};
} // namespace internal
}
inline BufferPtr
MakeBufferSmart(uint8_t *data, const int64_t size) {
MakeBufferSmart(uint8_t* data, const int64_t size) {
return BufferPtr(new Buffer(data, size), internal::BufferDeleter());
}
inline MutableBufferPtr
MakeMutableBufferSmart(uint8_t *data, const int64_t size) {
MakeMutableBufferSmart(uint8_t* data, const int64_t size) {
return MutableBufferPtr(new MutableBuffer(data, size), internal::BufferDeleter());
}
inline BufferPtr
MakeBuffer(uint8_t *data, const int64_t size) {
MakeBuffer(uint8_t* data, const int64_t size) {
return std::make_shared<Buffer>(data, size);
}
inline MutableBufferPtr
MakeMutableBuffer(uint8_t *data, const int64_t size) {
MakeMutableBuffer(uint8_t* data, const int64_t size) {
return std::make_shared<MutableBuffer>(data, size);
}
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,12 +15,10 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
namespace zilliz {
namespace knowhere {
enum class METRICTYPE {
......@@ -42,18 +40,17 @@ struct Cfg {
int64_t gpu_id = DEFAULT_GPUID;
int64_t d = DEFAULT_DIM;
Cfg(const int64_t &dim,
const int64_t &k,
const int64_t &gpu_id,
METRICTYPE type)
: d(dim), k(k), gpu_id(gpu_id), metric_type(type) {}
Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type)
: metric_type(type), k(k), gpu_id(gpu_id), d(dim) {
}
Cfg() = default;
virtual bool
CheckValid(){};
CheckValid() {
return true;
}
};
using Config = std::shared_ptr<Cfg>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,21 +15,19 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <vector>
#include <memory>
#include <utility>
#include <vector>
#include "Array.h"
#include "Buffer.h"
#include "Tensor.h"
#include "Schema.h"
#include "Config.h"
#include "Schema.h"
#include "Tensor.h"
#include "knowhere/adapter/ArrowAdapter.h"
namespace zilliz {
namespace knowhere {
class Dataset;
......@@ -40,34 +38,38 @@ class Dataset {
public:
Dataset() = default;
Dataset(std::vector<ArrayPtr> &&array, SchemaPtr array_schema,
std::vector<TensorPtr> &&tensor, SchemaPtr tensor_schema)
Dataset(std::vector<ArrayPtr>&& array, SchemaPtr array_schema, std::vector<TensorPtr>&& tensor,
SchemaPtr tensor_schema)
: array_(std::move(array)),
array_schema_(std::move(array_schema)),
tensor_(std::move(tensor)),
tensor_schema_(std::move(tensor_schema)) {}
tensor_schema_(std::move(tensor_schema)) {
}
Dataset(std::vector<ArrayPtr> array, SchemaPtr array_schema)
: array_(std::move(array)), array_schema_(std::move(array_schema)) {}
: array_(std::move(array)), array_schema_(std::move(array_schema)) {
}
Dataset(std::vector<TensorPtr> tensor, SchemaPtr tensor_schema)
: tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) {}
: tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) {
}
Dataset(const Dataset &) = delete;
Dataset &operator=(const Dataset &) = delete;
Dataset(const Dataset&) = delete;
Dataset&
operator=(const Dataset&) = delete;
DatasetPtr
Clone() {
auto dataset = std::make_shared<Dataset>();
std::vector<ArrayPtr> clone_array;
for (auto &array : array_) {
for (auto& array : array_) {
clone_array.emplace_back(CopyArray(array));
}
dataset->set_array(clone_array);
std::vector<TensorPtr> clone_tensor;
for (auto &tensor : tensor_) {
for (auto& tensor : tensor_) {
auto buffer = tensor->data();
std::shared_ptr<Buffer> copy_buffer;
// TODO: checkout copy success;
......@@ -86,16 +88,20 @@ class Dataset {
}
public:
const std::vector<ArrayPtr> &
array() const { return array_; }
const std::vector<ArrayPtr>&
array() const {
return array_;
}
void
set_array(std::vector<ArrayPtr> array) {
array_ = std::move(array);
}
const std::vector<TensorPtr> &
tensor() const { return tensor_; }
const std::vector<TensorPtr>&
tensor() const {
return tensor_;
}
void
set_tensor(std::vector<TensorPtr> tensor) {
......@@ -103,7 +109,9 @@ class Dataset {
}
SchemaConstPtr
array_schema() const { return array_schema_; }
array_schema() const {
return array_schema_;
}
void
set_array_schema(SchemaPtr array_schema) {
......@@ -111,31 +119,31 @@ class Dataset {
}
SchemaConstPtr
tensor_schema() const { return tensor_schema_; }
tensor_schema() const {
return tensor_schema_;
}
void
set_tensor_schema(SchemaPtr tensor_schema) {
tensor_schema_ = std::move(tensor_schema);
}
//const Config &
//meta() const { return meta_; }
// const Config &
// meta() const { return meta_; }
//void
//set_meta(Config meta) {
// void
// set_meta(Config meta) {
// meta_ = std::move(meta);
//}
private:
SchemaPtr array_schema_;
SchemaPtr tensor_schema_;
std::vector<ArrayPtr> array_;
SchemaPtr array_schema_;
std::vector<TensorPtr> tensor_;
//Config meta_;
SchemaPtr tensor_schema_;
// Config meta_;
};
using DatasetPtr = std::shared_ptr<Dataset>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,41 +15,35 @@
// specific language governing permissions and limitations
// under the License.
#include <cstdio>
#include "Exception.h"
#include "Log.h"
#include "knowhere/common/Exception.h"
namespace zilliz {
namespace knowhere {
KnowhereException::KnowhereException(const std::string& msg) : msg(msg) {
}
KnowhereException::KnowhereException(const std::string &msg):msg(msg) {}
KnowhereException::KnowhereException(const std::string &m, const char *funcName, const char *file, int line) {
KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) {
#ifdef DEBUG
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s",
funcName, file, line, m.c_str());
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, file, line, m.c_str());
msg.resize(size + 1);
snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s",
funcName, file, line, m.c_str());
snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", funcName, file, line, m.c_str());
#else
std::string file_path(file);
auto const pos = file_path.find_last_of('/');
auto filename = file_path.substr(pos+1).c_str();
auto filename = file_path.substr(pos + 1).c_str();
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s",
funcName, filename, line, m.c_str());
int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str());
msg.resize(size + 1);
snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s",
funcName, filename, line, m.c_str());
snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str());
#endif
}
const char *KnowhereException::what() const noexcept {
const char*
KnowhereException::what() const noexcept {
return msg.c_str();
}
}
}
\ No newline at end of file
} // namespace knowhere
......@@ -15,46 +15,39 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <exception>
#include <string>
namespace zilliz {
namespace knowhere {
class KnowhereException : public std::exception {
public:
explicit KnowhereException(const std::string &msg);
explicit KnowhereException(const std::string& msg);
KnowhereException(const std::string &msg, const char *funName,
const char *file, int line);
KnowhereException(const std::string& msg, const char* funName, const char* file, int line);
const char *what() const noexcept override;
const char*
what() const noexcept override;
std::string msg;
};
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
#define KNOHWERE_ERROR_MSG(MSG)\
printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
#define KNOWHERE_THROW_MSG(MSG)\
do {\
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);\
} while (false)
#define KNOHERE_THROW_FORMAT(FMT, ...)\
do { \
std::string __s;\
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__);\
__s.resize(__size + 1);\
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__);\
throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__);\
} while (false)
#define KNOWHERE_THROW_MSG(MSG) \
do { \
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
} while (false)
#define KNOHERE_THROW_FORMAT(FMT, ...) \
do { \
std::string __s; \
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \
__s.resize(__size + 1); \
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \
throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
} while (false)
}
}
\ No newline at end of file
} // namespace knowhere
......@@ -15,30 +15,27 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
//#include "zcommon/id/id.h"
//using ID = zilliz::common::ID;
#include <stdint.h>
#include <string>
namespace zilliz {
namespace knowhere {
class ID {
public:
constexpr static int64_t kIDSize = 20;
public:
const int32_t *
data() const { return content_; }
const int32_t*
data() const {
return content_;
}
int32_t *
mutable_data() { return content_; }
int32_t*
mutable_data() {
return content_;
}
bool
IsValid() const;
......@@ -47,14 +44,13 @@ class ID {
ToString() const;
bool
operator==(const ID &that) const;
operator==(const ID& that) const;
bool
operator<(const ID &that) const;
operator<(const ID& that) const;
protected:
int32_t content_[5] = {};
};
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,12 +15,10 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "utils/easylogging++.h"
namespace zilliz {
namespace knowhere {
#define KNOWHERE_DOMAIN_NAME "[KNOWHERE] "
......@@ -33,5 +31,4 @@ namespace knowhere {
#define KNOWHERE_LOG_ERROR LOG(ERROR) << KNOWHERE_DOMAIN_NAME
#define KNOWHERE_LOG_FATAL LOG(FATAL) << KNOWHERE_DOMAIN_NAME
} // namespace knowhere
} // namespace zilliz
\ No newline at end of file
} // namespace knowhere
......@@ -15,18 +15,14 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <arrow/type.h>
namespace zilliz {
namespace knowhere {
using DataType = arrow::DataType;
using Field = arrow::Field;
using FieldPtr = std::shared_ptr<arrow::Field>;
......@@ -34,7 +30,4 @@ using Schema = arrow::Schema;
using SchemaPtr = std::shared_ptr<Schema>;
using SchemaConstPtr = std::shared_ptr<const Schema>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,21 +15,15 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <arrow/tensor.h>
namespace zilliz {
namespace knowhere {
using Tensor = arrow::Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,18 +15,13 @@
// specific language governing permissions and limitations
// under the License.
#include <iostream> // TODO(linxj): using Log instead
#include <iostream> // TODO(linxj): using Log instead
#include "knowhere/common/Timer.h"
#include "Timer.h"
namespace zilliz {
namespace knowhere {
TimeRecorder::TimeRecorder(const std::string &header,
int64_t log_level) :
header_(header),
log_level_(log_level) {
TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) {
start_ = last_ = stdclock::now();
}
......@@ -42,9 +37,10 @@ TimeRecorder::GetTimeSpanStr(double span) {
}
void
TimeRecorder::PrintTimeRecord(const std::string &msg, double span) {
TimeRecorder::PrintTimeRecord(const std::string& msg, double span) {
std::string str_log;
if (!header_.empty()) str_log += header_ + ": ";
if (!header_.empty())
str_log += header_ + ": ";
str_log += msg;
str_log += " (";
str_log += TimeRecorder::GetTimeSpanStr(span);
......@@ -55,35 +51,35 @@ TimeRecorder::PrintTimeRecord(const std::string &msg, double span) {
std::cout << str_log << std::endl;
break;
}
//case 1: {
// SERVER_LOG_DEBUG << str_log;
// break;
//}
//case 2: {
// SERVER_LOG_INFO << str_log;
// break;
//}
//case 3: {
// SERVER_LOG_WARNING << str_log;
// break;
//}
//case 4: {
// SERVER_LOG_ERROR << str_log;
// break;
//}
//case 5: {
// SERVER_LOG_FATAL << str_log;
// break;
//}
//default: {
// SERVER_LOG_INFO << str_log;
// break;
//}
// case 1: {
// SERVER_LOG_DEBUG << str_log;
// break;
//}
// case 2: {
// SERVER_LOG_INFO << str_log;
// break;
//}
// case 3: {
// SERVER_LOG_WARNING << str_log;
// break;
//}
// case 4: {
// SERVER_LOG_ERROR << str_log;
// break;
//}
// case 5: {
// SERVER_LOG_FATAL << str_log;
// break;
//}
// default: {
// SERVER_LOG_INFO << str_log;
// break;
//}
}
}
double
TimeRecorder::RecordSection(const std::string &msg) {
TimeRecorder::RecordSection(const std::string& msg) {
stdclock::time_point curr = stdclock::now();
double span = (std::chrono::duration<double, std::micro>(curr - last_)).count();
last_ = curr;
......@@ -93,7 +89,7 @@ TimeRecorder::RecordSection(const std::string &msg) {
}
double
TimeRecorder::ElapseFromBegin(const std::string &msg) {
TimeRecorder::ElapseFromBegin(const std::string& msg) {
stdclock::time_point curr = stdclock::now();
double span = (std::chrono::duration<double, std::micro>(curr - start_)).count();
......@@ -101,5 +97,4 @@ TimeRecorder::ElapseFromBegin(const std::string &msg) {
return span;
}
}
}
\ No newline at end of file
} // namespace knowhere
......@@ -15,32 +15,33 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <string>
#include <chrono>
#include <string>
namespace zilliz {
namespace knowhere {
class TimeRecorder {
using stdclock = std::chrono::high_resolution_clock;
public:
TimeRecorder(const std::string &header,
int64_t log_level = 0);
explicit TimeRecorder(const std::string& header, int64_t log_level = 0);
~TimeRecorder();//trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
~TimeRecorder(); // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
double RecordSection(const std::string &msg);
double
RecordSection(const std::string& msg);
double ElapseFromBegin(const std::string &msg);
double
ElapseFromBegin(const std::string& msg);
static std::string GetTimeSpanStr(double span);
static std::string
GetTimeSpanStr(double span);
private:
void PrintTimeRecord(const std::string &msg, double span);
void
PrintTimeRecord(const std::string& msg, double span);
private:
std::string header_;
......@@ -49,5 +50,4 @@ class TimeRecorder {
int64_t log_level_;
};
}
}
} // namespace knowhere
......@@ -15,54 +15,53 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "IndexModel.h"
#include "IndexType.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/common/Dataset.h"
#include "IndexType.h"
#include "IndexModel.h"
#include "knowhere/index/preprocessor/Preprocessor.h"
namespace zilliz {
namespace knowhere {
class Index {
public:
virtual BinarySet
Serialize() = 0;
virtual void
Load(const BinarySet &index_binary) = 0;
Load(const BinarySet& index_binary) = 0;
// @throw
virtual DatasetPtr
Search(const DatasetPtr &dataset, const Config &config) = 0;
Search(const DatasetPtr& dataset, const Config& config) = 0;
public:
IndexType
idx_type() const { return idx_type_; }
idx_type() const {
return idx_type_;
}
void
set_idx_type(IndexType idx_type) { idx_type_ = idx_type; }
set_idx_type(IndexType idx_type) {
idx_type_ = idx_type;
}
virtual void
set_preprocessor(PreprocessorPtr preprocessor) {}
set_preprocessor(PreprocessorPtr preprocessor) {
}
virtual void
set_index_model(IndexModelPtr model) {}
set_index_model(IndexModelPtr model) {
}
private:
IndexType idx_type_;
};
using IndexPtr = std::shared_ptr<Index>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,28 +15,22 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "knowhere/common/BinarySet.h"
namespace zilliz {
namespace knowhere {
class IndexModel {
public:
virtual BinarySet
Serialize() = 0;
virtual void
Load(const BinarySet &binary) = 0;
Load(const BinarySet& binary) = 0;
};
using IndexModelPtr = std::shared_ptr<IndexModel>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,14 +15,10 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
namespace zilliz {
namespace knowhere {
enum class IndexType {
kUnknown = 0,
kVecIdxBegin = 100,
......@@ -30,6 +26,4 @@ enum class IndexType {
kVecIdxEnd,
};
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
//// Licensed to the Apache Software Foundation (ASF) under one
//// or more contributor license agreements. See the NOTICE file
//// distributed with this work for additional information
//// regarding copyright ownership. The ASF licenses this file
//// to you under the Apache License, Version 2.0 (the
//// "License"); you may not use this file except in compliance
//// with the License. You may obtain a copy of the License at
////
//// http://www.apache.org/licenses/LICENSE-2.0
////
//// Unless required by applicable law or agreed to in writing,
//// software distributed under the License is distributed on an
//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
//// KIND, either express or implied. See the License for the
//// specific language governing permissions and limitations
//// under the License.
//
//#include "knowhere/index/vector_index/definitions.h"
//#include "knowhere/common/config.h"
//#include "knowhere/index/preprocessor/normalize.h"
#include "knowhere/index/preprocessor/Normalize.h"
//
//
//namespace zilliz {
//namespace knowhere {
//
//DatasetPtr
//NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// namespace knowhere {
//
// DatasetPtr
// NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) {
// // TODO: wrap dataset->tensor
// auto tensor = dataset->tensor()[0];
// auto p_data = (float *)tensor->raw_mutable_data();
......@@ -21,8 +37,8 @@
// }
//}
//
//void
//NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
// void
// NormalizePreprocessor::Normalize(float *arr, int64_t dimension) {
// double vector_length = 0;
// for (auto j = 0; j < dimension; j++) {
// double val = arr[j];
......@@ -38,5 +54,4 @@
//}
//
//} // namespace knowhere
//} // namespace zilliz
//
//// Licensed to the Apache Software Foundation (ASF) under one
//// or more contributor license agreements. See the NOTICE file
//// distributed with this work for additional information
//// regarding copyright ownership. The ASF licenses this file
//// to you under the Apache License, Version 2.0 (the
//// "License"); you may not use this file except in compliance
//// with the License. You may obtain a copy of the License at
////
//// http://www.apache.org/licenses/LICENSE-2.0
////
//// Unless required by applicable law or agreed to in writing,
//// software distributed under the License is distributed on an
//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
//// KIND, either express or implied. See the License for the
//// specific language governing permissions and limitations
//// under the License.
//
//#pragma once
//
//#include <memory>
//#include "preprocessor.h"
//
//
//namespace zilliz {
//namespace knowhere {
//
//class NormalizePreprocessor : public Preprocessor {
// namespace knowhere {
//
// class NormalizePreprocessor : public Preprocessor {
// public:
// DatasetPtr
// Preprocess(const DatasetPtr &input) override;
......@@ -19,8 +36,8 @@
//};
//
//
//using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
// using NormalizePreprocessorPtr = std::shared_ptr<NormalizePreprocessor>;
//
//
//} // namespace knowhere
//} // namespace zilliz
//
......@@ -15,27 +15,20 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "knowhere/common/Dataset.h"
namespace zilliz {
namespace knowhere {
class Preprocessor {
public:
virtual DatasetPtr
Preprocess(const DatasetPtr &input) = 0;
Preprocess(const DatasetPtr& input) = 0;
};
using PreprocessorPtr = std::shared_ptr<Preprocessor>;
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,23 +15,23 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/index_io.h>
#include <faiss/IndexIVF.h>
#include <utility>
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/FaissBaseIndex.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
#include "FaissBaseIndex.h"
namespace zilliz {
namespace knowhere {
FaissBaseIndex::FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {}
FaissBaseIndex::FaissBaseIndex(std::shared_ptr<faiss::Index> index) : index_(std::move(index)) {
}
BinarySet FaissBaseIndex::SerializeImpl() {
BinarySet
FaissBaseIndex::SerializeImpl() {
try {
faiss::Index *index = index_.get();
faiss::Index* index = index_.get();
SealImpl();
......@@ -44,37 +44,37 @@ BinarySet FaissBaseIndex::SerializeImpl() {
// TODO(linxj): use virtual func Name() instead of raw string.
res_set.Append("IVF", data, writer.rp);
return res_set;
} catch (std::exception &e) {
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void FaissBaseIndex::LoadImpl(const BinarySet &index_binary) {
void
FaissBaseIndex::LoadImpl(const BinarySet& index_binary) {
auto binary = index_binary.GetByName("IVF");
MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();
faiss::Index *index = faiss::read_index(&reader);
faiss::Index* index = faiss::read_index(&reader);
index_.reset(index);
}
void FaissBaseIndex::SealImpl() {
// TODO(linxj): enable
//#ifdef ZILLIZ_FAISS
faiss::Index *index = index_.get();
auto idx = dynamic_cast<faiss::IndexIVF *>(index);
void
FaissBaseIndex::SealImpl() {
// TODO(linxj): enable
//#ifdef ZILLIZ_FAISS
faiss::Index* index = index_.get();
auto idx = dynamic_cast<faiss::IndexIVF*>(index);
if (idx != nullptr) {
idx->to_readonly();
}
//else {
// else {
// KNOHWERE_ERROR_MSG("Seal failed");
//}
//#endif
//#endif
}
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
......@@ -24,8 +23,6 @@
#include "knowhere/common/BinarySet.h"
namespace zilliz {
namespace knowhere {
class FaissBaseIndex {
......@@ -36,7 +33,7 @@ class FaissBaseIndex {
SerializeImpl();
virtual void
LoadImpl(const BinarySet &index_binary);
LoadImpl(const BinarySet& index_binary);
virtual void
SealImpl();
......@@ -45,8 +42,4 @@ class FaissBaseIndex {
std::shared_ptr<faiss::Index> index_ = nullptr;
};
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,30 +15,27 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/index_io.h>
#include <memory>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/helpers/Cloner.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "IndexGPUIVF.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
GPUIVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
......@@ -49,10 +46,9 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
ResScope rs(temp_resource, gpu_id_, true);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_id_;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim,
build_cfg->nlist, GetMetricType(build_cfg->metric_type),
idx_config);
device_index.train(rows, (float *) p_data);
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type), idx_config);
device_index.train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
......@@ -63,7 +59,8 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
}
}
void GPUIVF::set_index_model(IndexModelPtr model) {
void
GPUIVF::set_index_model(IndexModelPtr model) {
std::lock_guard<std::mutex> lk(mutex_);
auto host_index = std::static_pointer_cast<IVFIndexModel>(model);
......@@ -77,7 +74,8 @@ void GPUIVF::set_index_model(IndexModelPtr model) {
}
}
BinarySet GPUIVF::SerializeImpl() {
BinarySet
GPUIVF::SerializeImpl() {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
......@@ -85,8 +83,8 @@ BinarySet GPUIVF::SerializeImpl() {
try {
MemoryIOWriter writer;
{
faiss::Index *index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index);
faiss::Index* index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index);
SealImpl();
......@@ -100,19 +98,20 @@ BinarySet GPUIVF::SerializeImpl() {
res_set.Append("IVF", data, writer.rp);
return res_set;
} catch (std::exception &e) {
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void GPUIVF::LoadImpl(const BinarySet &index_binary) {
void
GPUIVF::LoadImpl(const BinarySet& index_binary) {
auto binary = index_binary.GetByName("IVF");
MemoryIOReader reader;
{
reader.total = binary->size;
reader.data_ = binary->data.get();
faiss::Index *index = faiss::read_index(&reader);
faiss::Index* index = faiss::read_index(&reader);
if (auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(temp_res, gpu_id_, false);
......@@ -127,12 +126,8 @@ void GPUIVF::LoadImpl(const BinarySet &index_binary) {
}
}
void GPUIVF::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) {
void
GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
std::lock_guard<std::mutex> lk(mutex_);
// TODO(linxj): gpu index support GenParams
......@@ -143,19 +138,20 @@ void GPUIVF::search_impl(int64_t n,
{
// TODO(linxj): allocate gpu mem
ResScope rs(res_, gpu_id_);
device_index->search(n, (float *) data, k, distances, labels);
device_index->search(n, (float*)data, k, distances, labels);
}
} else {
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");
}
}
VectorIndexPtr GPUIVF::CopyGpuToCpu(const Config &config) {
VectorIndexPtr
GPUIVF::CopyGpuToCpu(const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
if ( auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
faiss::Index *device_index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
if (auto device_idx = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
faiss::Index* device_index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
......@@ -165,33 +161,36 @@ VectorIndexPtr GPUIVF::CopyGpuToCpu(const Config &config) {
}
}
VectorIndexPtr GPUIVF::Clone() {
VectorIndexPtr
GPUIVF::Clone() {
auto cpu_idx = CopyGpuToCpu(Config());
return ::zilliz::knowhere::cloner::CopyCpuToGpu(cpu_idx, gpu_id_, Config());
return knowhere::cloner::CopyCpuToGpu(cpu_idx, gpu_id_, Config());
}
VectorIndexPtr GPUIVF::CopyGpuToGpu(const int64_t &device_id, const Config &config) {
VectorIndexPtr
GPUIVF::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
auto host_index = CopyGpuToCpu(config);
return std::static_pointer_cast<IVF>(host_index)->CopyCpuToGpu(device_id, config);
}
void GPUIVF::Add(const DatasetPtr &dataset, const Config &config) {
void
GPUIVF::Add(const DatasetPtr& dataset, const Config& config) {
if (auto spt = res_.lock()) {
ResScope rs(res_, gpu_id_);
IVF::Add(dataset, config);
}
else {
} else {
KNOWHERE_THROW_MSG("Add IVF can't get gpu resource");
}
}
void GPUIndex::SetGpuDevice(const int &gpu_id) {
void
GPUIndex::SetGpuDevice(const int& gpu_id) {
gpu_id_ = gpu_id;
}
const int64_t &GPUIndex::GetGpuDevice() {
const int64_t&
GPUIndex::GetGpuDevice() {
return gpu_id_;
}
}
}
} // namespace knowhere
......@@ -15,81 +15,78 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <utility>
#include "IndexIVF.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
namespace zilliz {
namespace knowhere {
class GPUIndex {
public:
explicit GPUIndex(const int &device_id) : gpu_id_(device_id) {}
public:
explicit GPUIndex(const int& device_id) : gpu_id_(device_id) {
}
GPUIndex(const int& device_id, const ResPtr& resource): gpu_id_(device_id), res_(resource) {}
GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) {
}
virtual VectorIndexPtr
CopyGpuToCpu(const Config &config) = 0;
virtual VectorIndexPtr
CopyGpuToCpu(const Config& config) = 0;
virtual VectorIndexPtr
CopyGpuToGpu(const int64_t &device_id, const Config &config) = 0;
virtual VectorIndexPtr
CopyGpuToGpu(const int64_t& device_id, const Config& config) = 0;
void
SetGpuDevice(const int &gpu_id);
void
SetGpuDevice(const int& gpu_id);
const int64_t &
GetGpuDevice();
const int64_t&
GetGpuDevice();
protected:
int64_t gpu_id_;
ResWPtr res_;
protected:
int64_t gpu_id_;
ResWPtr res_;
};
class GPUIVF : public IVF, public GPUIndex {
public:
explicit GPUIVF(const int &device_id) : IVF(), GPUIndex(device_id) {}
public:
explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) {
}
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
: IVF(std::move(index)), GPUIndex(device_id, resource) {};
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
: IVF(std::move(index)), GPUIndex(device_id, resource) {
}
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
IndexModelPtr
Train(const DatasetPtr& dataset, const Config& config) override;
void
Add(const DatasetPtr &dataset, const Config &config) override;
void
Add(const DatasetPtr& dataset, const Config& config) override;
void
set_index_model(IndexModelPtr model) override;
void
set_index_model(IndexModelPtr model) override;
//DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
VectorIndexPtr
CopyGpuToCpu(const Config &config) override;
// DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
VectorIndexPtr
CopyGpuToCpu(const Config& config) override;
VectorIndexPtr
CopyGpuToGpu(const int64_t &device_id, const Config &config) override;
VectorIndexPtr
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
VectorIndexPtr
Clone() final;
VectorIndexPtr
Clone() final;
protected:
void
search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) override;
protected:
void
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
BinarySet
SerializeImpl() override;
BinarySet
SerializeImpl() override;
void
LoadImpl(const BinarySet &index_binary) override;
void
LoadImpl(const BinarySet& index_binary) override;
};
} // namespace knowhere
} // namespace zilliz
} // namespace knowhere
......@@ -15,23 +15,23 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <memory>
#include "IndexGPUIVFPQ.h"
#include "knowhere/common/Exception.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
......@@ -40,9 +40,9 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
// TODO(linxj): set device here.
// TODO(linxj): set gpu resource here.
faiss::gpu::StandardGpuResources res;
faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, build_cfg->nlist, build_cfg->m,
build_cfg->nbits, GetMetricType(build_cfg->metric_type)); // IP not support
device_index.train(rows, (float *) p_data);
faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits,
GetMetricType(build_cfg->metric_type)); // IP not support
device_index.train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
......@@ -50,20 +50,21 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
return std::make_shared<IVFIndexModel>(host_index);
}
std::shared_ptr<faiss::IVFSearchParameters> GPUIVFPQ::GenParams(const Config &config) {
std::shared_ptr<faiss::IVFSearchParameters>
GPUIVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
return params;
}
VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) {
VectorIndexPtr
GPUIVFPQ::CopyGpuToCpu(const Config& config) {
KNOWHERE_THROW_MSG("not support yet");
}
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,33 +15,30 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "IndexGPUIVF.h"
namespace zilliz {
namespace knowhere {
class GPUIVFPQ : public GPUIVF {
public:
explicit GPUIVFPQ(const int &device_id) : GPUIVF(device_id) {}
public:
explicit GPUIVFPQ(const int& device_id) : GPUIVF(device_id) {
}
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
Train(const DatasetPtr& dataset, const Config& config) override;
public:
public:
VectorIndexPtr
CopyGpuToCpu(const Config &config) override;
CopyGpuToCpu(const Config& config) override;
protected:
protected:
// TODO(linxj): remove GenParams.
std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config &config) override;
GenParams(const Config& config) override;
};
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,60 +15,60 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/gpu/GpuAutoTune.h>
#include <memory>
#include <utility>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "IndexGPUIVFSQ.h"
#include "IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
IndexModelPtr
GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << "," << "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
delete device_index;
delete build_index;
delete device_index;
delete build_index;
return std::make_shared<IVFIndexModel>(host_index);
} else {
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
}
return std::make_shared<IVFIndexModel>(host_index);
} else {
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
}
}
VectorIndexPtr GPUIVFSQ::CopyGpuToCpu(const Config &config) {
std::lock_guard<std::mutex> lk(mutex_);
VectorIndexPtr
GPUIVFSQ::CopyGpuToCpu(const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
faiss::Index *device_index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
return std::make_shared<IVFSQ>(new_index);
}
faiss::Index* device_index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
} // knowhere
} // zilliz
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
return std::make_shared<IVFSQ>(new_index);
}
} // namespace knowhere
......@@ -15,29 +15,29 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "IndexGPUIVF.h"
#include <memory>
#include <utility>
#include "IndexGPUIVF.h"
namespace zilliz {
namespace knowhere {
class GPUIVFSQ : public GPUIVF {
public:
explicit GPUIVFSQ(const int &device_id) : GPUIVF(device_id) {}
public:
explicit GPUIVFSQ(const int& device_id) : GPUIVF(device_id) {
}
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
: GPUIVF(std::move(index), device_id, resource) {};
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
: GPUIVF(std::move(index), device_id, resource) {
}
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
Train(const DatasetPtr& dataset, const Config& config) override;
VectorIndexPtr
CopyGpuToCpu(const Config &config) override;
CopyGpuToCpu(const Config& config) override;
};
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,24 +15,22 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/IndexFlat.h>
#include <faiss/AutoTune.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/index_io.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/index_io.h>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexIDMAP.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"
#include "IndexIDMAP.h"
namespace zilliz {
namespace knowhere {
BinarySet IDMAP::Serialize() {
BinarySet
IDMAP::Serialize() {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
......@@ -41,31 +39,33 @@ BinarySet IDMAP::Serialize() {
return SerializeImpl();
}
void IDMAP::Load(const BinarySet &index_binary) {
void
IDMAP::Load(const BinarySet& index_binary) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary);
}
DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) {
DatasetPtr
IDMAP::Search(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
config->CheckValid();
//auto metric_type = config["metric_type"].as_string() == "L2" ?
// auto metric_type = config["metric_type"].as_string() == "L2" ?
// faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
//index_->metric_type = metric_type;
// index_->metric_type = metric_type;
GETTENSOR(dataset)
auto elems = rows * config->k;
auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems);
auto res_dis = (float *) malloc(sizeof(float) * elems);
auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems);
auto res_dis = (float*)malloc(sizeof(float) * elems);
search_impl(rows, (float *) p_data, config->k, res_dis, res_ids, Config());
search_impl(rows, (float*)p_data, config->k, res_dis, res_ids, Config());
auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems);
auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems);
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
std::vector<BufferPtr> dist_bufs{nullptr, dist_buf};
......@@ -83,12 +83,13 @@ DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) {
return std::make_shared<Dataset>(array, nullptr);
}
void IDMAP::search_impl(int64_t n, const float *data, int64_t k, float *distances, int64_t *labels, const Config &cfg) {
index_->search(n, (float *) data, k, distances, labels);
void
IDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
index_->search(n, (float*)data, k, distances, labels);
}
void IDMAP::Add(const DatasetPtr &dataset, const Config &config) {
void
IDMAP::Add(const DatasetPtr& dataset, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}
......@@ -98,49 +99,56 @@ void IDMAP::Add(const DatasetPtr &dataset, const Config &config) {
// TODO: magic here.
auto array = dataset->array()[0];
auto p_ids = array->data()->GetValues<long>(1, 0);
auto p_ids = array->data()->GetValues<int64_t>(1, 0);
index_->add_with_ids(rows, (float *) p_data, p_ids);
index_->add_with_ids(rows, (float*)p_data, p_ids);
}
int64_t IDMAP::Count() {
int64_t
IDMAP::Count() {
return index_->ntotal;
}
int64_t IDMAP::Dimension() {
int64_t
IDMAP::Dimension() {
return index_->d;
}
// TODO(linxj): return const pointer
float *IDMAP::GetRawVectors() {
float*
IDMAP::GetRawVectors() {
try {
auto file_index = dynamic_cast<faiss::IndexIDMap *>(index_.get());
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
auto flat_index = dynamic_cast<faiss::IndexFlat*>(file_index->index);
return flat_index->xb.data();
} catch (std::exception &e) {
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
// TODO(linxj): return const pointer
int64_t *IDMAP::GetRawIds() {
int64_t*
IDMAP::GetRawIds() {
try {
auto file_index = dynamic_cast<faiss::IndexIDMap *>(index_.get());
auto file_index = dynamic_cast<faiss::IndexIDMap*>(index_.get());
return file_index->id_map.data();
} catch (std::exception &e) {
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
const char* type = "IDMap,Flat";
void IDMAP::Train(const Config &config) {
void
IDMAP::Train(const Config& config) {
config->CheckValid();
auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type));
index_.reset(index);
}
VectorIndexPtr IDMAP::Clone() {
VectorIndexPtr
IDMAP::Clone() {
std::lock_guard<std::mutex> lk(mutex_);
auto clone_index = faiss::clone_index(index_.get());
......@@ -149,8 +157,9 @@ VectorIndexPtr IDMAP::Clone() {
return std::make_shared<IDMAP>(new_index);
}
VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
VectorIndexPtr
IDMAP::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
......@@ -162,38 +171,41 @@ VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &confi
}
}
void IDMAP::Seal() {
void
IDMAP::Seal() {
// do nothing
}
VectorIndexPtr GPUIDMAP::CopyGpuToCpu(const Config &config) {
VectorIndexPtr
GPUIDMAP::CopyGpuToCpu(const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
faiss::Index *device_index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
faiss::Index* device_index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
return std::make_shared<IDMAP>(new_index);
}
VectorIndexPtr GPUIDMAP::Clone() {
VectorIndexPtr
GPUIDMAP::Clone() {
auto cpu_idx = CopyGpuToCpu(Config());
if (auto idmap = std::dynamic_pointer_cast<IDMAP>(cpu_idx)){
if (auto idmap = std::dynamic_pointer_cast<IDMAP>(cpu_idx)) {
return idmap->CopyCpuToGpu(gpu_id_, Config());
}
else {
} else {
KNOWHERE_THROW_MSG("IndexType not Support GpuClone");
}
}
BinarySet GPUIDMAP::SerializeImpl() {
BinarySet
GPUIDMAP::SerializeImpl() {
try {
MemoryIOWriter writer;
{
faiss::Index *index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index);
faiss::Index* index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index);
faiss::write_index(host_index, &writer);
delete host_index;
......@@ -205,21 +217,22 @@ BinarySet GPUIDMAP::SerializeImpl() {
res_set.Append("IVF", data, writer.rp);
return res_set;
} catch (std::exception &e) {
} catch (std::exception& e) {
KNOWHERE_THROW_MSG(e.what());
}
}
void GPUIDMAP::LoadImpl(const BinarySet &index_binary) {
void
GPUIDMAP::LoadImpl(const BinarySet& index_binary) {
auto binary = index_binary.GetByName("IVF");
MemoryIOReader reader;
{
reader.total = binary->size;
reader.data_ = binary->data.get();
faiss::Index *index = faiss::read_index(&reader);
faiss::Index* index = faiss::read_index(&reader);
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_) ){
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(res, gpu_id_, false);
auto device_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index);
index_.reset(device_index);
......@@ -232,28 +245,26 @@ void GPUIDMAP::LoadImpl(const BinarySet &index_binary) {
}
}
VectorIndexPtr GPUIDMAP::CopyGpuToGpu(const int64_t &device_id, const Config &config) {
VectorIndexPtr
GPUIDMAP::CopyGpuToGpu(const int64_t& device_id, const Config& config) {
auto cpu_index = CopyGpuToCpu(config);
return std::static_pointer_cast<IDMAP>(cpu_index)->CopyCpuToGpu(device_id, config);
}
float *GPUIDMAP::GetRawVectors() {
float*
GPUIDMAP::GetRawVectors() {
KNOWHERE_THROW_MSG("Not support");
}
int64_t *GPUIDMAP::GetRawIds() {
int64_t*
GPUIDMAP::GetRawIds() {
KNOWHERE_THROW_MSG("Not support");
}
void GPUIDMAP::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) {
void
GPUIDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
ResScope rs(res_, gpu_id_);
index_->search(n, (float *) data, k, distances, labels);
index_->search(n, (float*)data, k, distances, labels);
}
}
}
} // namespace knowhere
......@@ -15,41 +15,53 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "IndexIVF.h"
#include "IndexGPUIVF.h"
#include "IndexIVF.h"
#include <memory>
#include <utility>
namespace zilliz {
namespace knowhere {
class IDMAP : public VectorIndex, public FaissBaseIndex {
public:
IDMAP() : FaissBaseIndex(nullptr) {};
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {};
BinarySet Serialize() override;
void Load(const BinarySet &index_binary) override;
void Train(const Config &config);
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
int64_t Count() override;
VectorIndexPtr Clone() override;
int64_t Dimension() override;
void Add(const DatasetPtr &dataset, const Config &config) override;
VectorIndexPtr CopyCpuToGpu(const int64_t &device_id, const Config &config);
void Seal() override;
IDMAP() : FaissBaseIndex(nullptr) {
}
explicit IDMAP(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
}
BinarySet
Serialize() override;
void
Load(const BinarySet& index_binary) override;
void
Train(const Config& config);
DatasetPtr
Search(const DatasetPtr& dataset, const Config& config) override;
int64_t
Count() override;
VectorIndexPtr
Clone() override;
int64_t
Dimension() override;
void
Add(const DatasetPtr& dataset, const Config& config) override;
VectorIndexPtr
CopyCpuToGpu(const int64_t& device_id, const Config& config);
void
Seal() override;
virtual float *GetRawVectors();
virtual int64_t *GetRawIds();
virtual float*
GetRawVectors();
virtual int64_t*
GetRawIds();
protected:
virtual void search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg);
virtual void
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
std::mutex mutex_;
};
......@@ -57,27 +69,30 @@ using IDMAPPtr = std::shared_ptr<IDMAP>;
class GPUIDMAP : public IDMAP, public GPUIndex {
public:
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr& res)
: IDMAP(std::move(index)), GPUIndex(device_id, res) {}
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& res)
: IDMAP(std::move(index)), GPUIndex(device_id, res) {
}
VectorIndexPtr CopyGpuToCpu(const Config &config) override;
float *GetRawVectors() override;
int64_t *GetRawIds() override;
VectorIndexPtr Clone() override;
VectorIndexPtr CopyGpuToGpu(const int64_t &device_id, const Config &config) override;
VectorIndexPtr
CopyGpuToCpu(const Config& config) override;
float*
GetRawVectors() override;
int64_t*
GetRawIds() override;
VectorIndexPtr
Clone() override;
VectorIndexPtr
CopyGpuToGpu(const int64_t& device_id, const Config& config) override;
protected:
void search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) override;
BinarySet SerializeImpl() override;
void LoadImpl(const BinarySet &index_binary) override;
void
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override;
BinarySet
SerializeImpl() override;
void
LoadImpl(const BinarySet& index_binary) override;
};
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
}
}
} // namespace knowhere
......@@ -15,47 +15,46 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/AutoTune.h>
#include <faiss/AuxIndexStructures.h>
#include <faiss/IVFlib.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVF.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/AutoTune.h>
#include <faiss/IVFlib.h>
#include <faiss/AuxIndexStructures.h>
#include <faiss/index_io.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/index_io.h>
#include <memory>
#include <utility>
#include <vector>
#include "knowhere/common/Exception.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "IndexIVF.h"
#include "IndexGPUIVF.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexIVF.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr IVF::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
IVF::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index *coarse_quantizer = new faiss::IndexFlatL2(dim);
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim,
build_cfg->nlist,
faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim);
auto index = std::make_shared<faiss::IndexIVFFlat>(coarse_quantizer, dim, build_cfg->nlist,
GetMetricType(build_cfg->metric_type));
index->train(rows, (float *) p_data);
index->train(rows, (float*)p_data);
// TODO(linxj): override here. train return model or not.
return std::make_shared<IVFIndexModel>(index);
}
void IVF::Add(const DatasetPtr &dataset, const Config &config) {
void
IVF::Add(const DatasetPtr& dataset, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
......@@ -64,11 +63,12 @@ void IVF::Add(const DatasetPtr &dataset, const Config &config) {
GETTENSOR(dataset)
auto array = dataset->array()[0];
auto p_ids = array->data()->GetValues<long>(1, 0);
index_->add_with_ids(rows, (float *) p_data, p_ids);
auto p_ids = array->data()->GetValues<int64_t>(1, 0);
index_->add_with_ids(rows, (float*)p_data, p_ids);
}
void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) {
void
IVF::AddWithoutIds(const DatasetPtr& dataset, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
......@@ -76,10 +76,11 @@ void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) {
std::lock_guard<std::mutex> lk(mutex_);
GETTENSOR(dataset)
index_->add(rows, (float *) p_data);
index_->add(rows, (float*)p_data);
}
BinarySet IVF::Serialize() {
BinarySet
IVF::Serialize() {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
......@@ -89,31 +90,33 @@ BinarySet IVF::Serialize() {
return SerializeImpl();
}
void IVF::Load(const BinarySet &index_binary) {
void
IVF::Load(const BinarySet& index_binary) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(index_binary);
}
DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) {
DatasetPtr
IVF::Search(const DatasetPtr& dataset, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
if (search_cfg != nullptr) {
search_cfg->CheckValid(); // throw exception
search_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
auto elems = rows * search_cfg->k;
auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems);
auto res_dis = (float *) malloc(sizeof(float) * elems);
auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems);
auto res_dis = (float*)malloc(sizeof(float) * elems);
search_impl(rows, (float*) p_data, search_cfg->k, res_dis, res_ids, config);
search_impl(rows, (float*)p_data, search_cfg->k, res_dis, res_ids, config);
auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems);
auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems);
auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems);
std::vector<BufferPtr> id_bufs{nullptr, id_buf};
std::vector<BufferPtr> dist_bufs{nullptr, dist_buf};
......@@ -131,7 +134,8 @@ DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) {
return std::make_shared<Dataset>(array, nullptr);
}
void IVF::set_index_model(IndexModelPtr model) {
void
IVF::set_index_model(IndexModelPtr model) {
std::lock_guard<std::mutex> lk(mutex_);
auto rel_model = std::static_pointer_cast<IVFIndexModel>(model);
......@@ -140,25 +144,29 @@ void IVF::set_index_model(IndexModelPtr model) {
index_.reset(faiss::clone_index(rel_model->index_.get()));
}
std::shared_ptr<faiss::IVFSearchParameters> IVF::GenParams(const Config &config) {
std::shared_ptr<faiss::IVFSearchParameters>
IVF::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFCfg>(config);
params->nprobe = search_cfg->nprobe;
//params->max_codes = config.get_with_default("max_codes", size_t(0));
// params->max_codes = config.get_with_default("max_codes", size_t(0));
return params;
}
int64_t IVF::Count() {
int64_t
IVF::Count() {
return index_->ntotal;
}
int64_t IVF::Dimension() {
int64_t
IVF::Dimension() {
return index_->d;
}
void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config) {
void
IVF::GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config) {
GETTENSOR(dataset)
auto ntotal = Count();
......@@ -174,7 +182,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
for (int i = 0; i < total_search_count; ++i) {
auto b_size = i == total_search_count - 1 && tail_batch_size != 0 ? tail_batch_size : batch_size;
auto &res = res_vec[i];
auto& res = res_vec[i];
res.resize(k * b_size);
auto xq = p_data + batch_size * dim * i;
......@@ -182,7 +190,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
int tmp = 0;
for (int j = 0; j < b_size; ++j) {
auto &node = graph[batch_size * i + j];
auto& node = graph[batch_size * i + j];
node.resize(k);
for (int m = 0; m < k && tmp < k * b_size; ++m, ++tmp) {
// TODO(linxj): avoid memcopy here.
......@@ -192,18 +200,15 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co
}
}
void IVF::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) {
void
IVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) {
auto params = GenParams(cfg);
faiss::ivflib::search_with_parameters(index_.get(), n, (float *) data, k, distances, labels, params.get());
faiss::ivflib::search_with_parameters(index_.get(), n, (float*)data, k, distances, labels, params.get());
}
VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
VectorIndexPtr
IVF::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
......@@ -215,7 +220,8 @@ VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config)
}
}
VectorIndexPtr IVF::Clone() {
VectorIndexPtr
IVF::Clone() {
std::lock_guard<std::mutex> lk(mutex_);
auto clone_index = faiss::clone_index(index_.get());
......@@ -224,21 +230,24 @@ VectorIndexPtr IVF::Clone() {
return Clone_impl(new_index);
}
VectorIndexPtr IVF::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
VectorIndexPtr
IVF::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
return std::make_shared<IVF>(index);
}
void IVF::Seal() {
void
IVF::Seal() {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}
SealImpl();
}
IVFIndexModel::IVFIndexModel(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
}
IVFIndexModel::IVFIndexModel(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {}
BinarySet IVFIndexModel::Serialize() {
BinarySet
IVFIndexModel::Serialize() {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("indexmodel not initialize or trained");
}
......@@ -246,18 +255,15 @@ BinarySet IVFIndexModel::Serialize() {
return SerializeImpl();
}
void IVFIndexModel::Load(const BinarySet &binary_set) {
void
IVFIndexModel::Load(const BinarySet& binary_set) {
std::lock_guard<std::mutex> lk(mutex_);
LoadImpl(binary_set);
}
void IVFIndexModel::SealImpl() {
void
IVFIndexModel::SealImpl() {
// do nothing
}
}
}
} // namespace knowhere
......@@ -15,54 +15,55 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "VectorIndex.h"
#include "FaissBaseIndex.h"
#include "VectorIndex.h"
#include "faiss/IndexIVF.h"
namespace zilliz {
namespace knowhere {
using Graph = std::vector<std::vector<int64_t>>;
class IVF : public VectorIndex, protected FaissBaseIndex {
public:
IVF() : FaissBaseIndex(nullptr) {};
IVF() : FaissBaseIndex(nullptr) {
}
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {}
explicit IVF(std::shared_ptr<faiss::Index> index) : FaissBaseIndex(std::move(index)) {
}
VectorIndexPtr
Clone() override;;
Clone() override;
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
Train(const DatasetPtr& dataset, const Config& config) override;
void
set_index_model(IndexModelPtr model) override;
void
Add(const DatasetPtr &dataset, const Config &config) override;
Add(const DatasetPtr& dataset, const Config& config) override;
void
AddWithoutIds(const DatasetPtr &dataset, const Config &config);
AddWithoutIds(const DatasetPtr& dataset, const Config& config);
DatasetPtr
Search(const DatasetPtr &dataset, const Config &config) override;
Search(const DatasetPtr& dataset, const Config& config) override;
void
GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config);
GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config);
BinarySet
Serialize() override;
void
Load(const BinarySet &index_binary) override;
Load(const BinarySet& index_binary) override;
int64_t
Count() override;
......@@ -74,23 +75,17 @@ class IVF : public VectorIndex, protected FaissBaseIndex {
Seal() override;
virtual VectorIndexPtr
CopyCpuToGpu(const int64_t &device_id, const Config &config);
CopyCpuToGpu(const int64_t& device_id, const Config& config);
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config &config);
GenParams(const Config& config);
virtual VectorIndexPtr
Clone_impl(const std::shared_ptr<faiss::Index> &index);
Clone_impl(const std::shared_ptr<faiss::Index>& index);
virtual void
search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg);
search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg);
protected:
std::mutex mutex_;
......@@ -106,13 +101,14 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex {
public:
explicit IVFIndexModel(std::shared_ptr<faiss::Index> index);
IVFIndexModel() : FaissBaseIndex(nullptr) {};
IVFIndexModel() : FaissBaseIndex(nullptr) {
}
BinarySet
Serialize() override;
void
Load(const BinarySet &binary) override;
Load(const BinarySet& binary) override;
protected:
void
......@@ -121,7 +117,7 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex {
protected:
std::mutex mutex_;
};
using IVFIndexModelPtr = std::shared_ptr<IVFIndexModel>;
}
}
\ No newline at end of file
} // namespace knowhere
......@@ -15,47 +15,49 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <memory>
#include <utility>
#include "IndexIVFPQ.h"
#include "knowhere/common/Exception.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr IVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
IVFPQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
faiss::Index *coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
auto index = std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim,
build_cfg->nlist, build_cfg->m, build_cfg->nbits);
index->train(rows, (float *) p_data);
faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type));
auto index =
std::make_shared<faiss::IndexIVFPQ>(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits);
index->train(rows, (float*)p_data);
return std::make_shared<IVFIndexModel>(index);
}
std::shared_ptr<faiss::IVFSearchParameters> IVFPQ::GenParams(const Config &config) {
std::shared_ptr<faiss::IVFSearchParameters>
IVFPQ::GenParams(const Config& config) {
auto params = std::make_shared<faiss::IVFPQSearchParameters>();
auto search_cfg = std::dynamic_pointer_cast<IVFPQCfg>(config);
params->nprobe = search_cfg->nprobe;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
// params->scan_table_threshold = conf->scan_table_threhold;
// params->polysemous_ht = conf->polysemous_ht;
// params->max_codes = conf->max_codes;
return params;
}
VectorIndexPtr IVFPQ::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
VectorIndexPtr
IVFPQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
return std::make_shared<IVFPQ>(index);
}
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,33 +15,31 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <utility>
#include "IndexIVF.h"
namespace zilliz {
namespace knowhere {
class IVFPQ : public IVF {
public:
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {}
public:
explicit IVFPQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
}
IVFPQ() = default;
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
Train(const DatasetPtr& dataset, const Config& config) override;
protected:
protected:
std::shared_ptr<faiss::IVFSearchParameters>
GenParams(const Config &config) override;
GenParams(const Config& config) override;
VectorIndexPtr
Clone_impl(const std::shared_ptr<faiss::Index> &index) override;
Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
};
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,44 +15,45 @@
// specific language governing permissions and limitations
// under the License.
#include <faiss/gpu/GpuAutoTune.h>
#include <memory>
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "IndexIVFSQ.h"
#include "IndexGPUIVFSQ.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr IVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
IndexModelPtr
IVFSQ::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << "," << "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(),
GetMetricType(build_cfg->metric_type));
build_index->train(rows, (float *) p_data);
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ" << build_cfg->nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
build_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> ret_index;
ret_index.reset(build_index);
return std::make_shared<IVFIndexModel>(ret_index);
}
VectorIndexPtr IVFSQ::Clone_impl(const std::shared_ptr<faiss::Index> &index) {
VectorIndexPtr
IVFSQ::Clone_impl(const std::shared_ptr<faiss::Index>& index) {
return std::make_shared<IVFSQ>(index);
}
VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
VectorIndexPtr
IVFSQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
faiss::gpu::GpuClonerOptions option;
option.allInGpu = true;
......@@ -67,5 +68,4 @@ VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &confi
}
}
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -15,31 +15,31 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <utility>
#include "IndexIVF.h"
namespace zilliz {
namespace knowhere {
class IVFSQ : public IVF {
public:
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {}
public:
explicit IVFSQ(std::shared_ptr<faiss::Index> index) : IVF(std::move(index)) {
}
IVFSQ() = default;
IndexModelPtr
Train(const DatasetPtr &dataset, const Config &config) override;
Train(const DatasetPtr& dataset, const Config& config) override;
VectorIndexPtr
CopyCpuToGpu(const int64_t &device_id, const Config &config) override;
CopyCpuToGpu(const int64_t& device_id, const Config& config) override;
protected:
protected:
VectorIndexPtr
Clone_impl(const std::shared_ptr<faiss::Index> &index) override;
Clone_impl(const std::shared_ptr<faiss::Index>& index) override;
};
} // knowhere
} // zilliz
} // namespace knowhere
......@@ -16,36 +16,35 @@
// specific language governing permissions and limitations
// under the License.
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"
#include "faiss/AutoTune.h"
#include "faiss/gpu/GpuAutoTune.h"
#include "faiss/gpu/GpuIndexIVF.h"
#include "knowhere/adapter/VectorAdapter.h"
#include "knowhere/common/Exception.h"
#include "IndexIVFSQHybrid.h"
#include "faiss/AutoTune.h"
#include "faiss/gpu/GpuAutoTune.h"
namespace zilliz {
namespace knowhere {
IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
if (build_cfg != nullptr) {
build_cfg->CheckValid(); // throw exception
build_cfg->CheckValid(); // throw exception
}
gpu_id_ = build_cfg->gpu_id;
GETTENSOR(dataset)
std::stringstream index_type;
index_type << "IVF" << build_cfg->nlist << "," << "SQ8Hybrid";
index_type << "IVF" << build_cfg->nlist << ","
<< "SQ8Hybrid";
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type));
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_resource != nullptr) {
ResScope rs(temp_resource, gpu_id_, true);
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float *) p_data);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
......@@ -60,12 +59,12 @@ IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) {
}
VectorIndexPtr
IVFSQHybrid::CopyGpuToCpu(const Config &config) {
IVFSQHybrid::CopyGpuToCpu(const Config& config) {
std::lock_guard<std::mutex> lk(mutex_);
if (auto device_idx = std::dynamic_pointer_cast<faiss::IndexIVF>(index_)) {
faiss::Index *device_index = index_.get();
faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index);
faiss::Index* device_index = index_.get();
faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index);
std::shared_ptr<faiss::Index> new_index;
new_index.reset(host_index);
......@@ -77,7 +76,7 @@ IVFSQHybrid::CopyGpuToCpu(const Config &config) {
}
VectorIndexPtr
IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false);
faiss::gpu::GpuClonerOptions option;
......@@ -86,7 +85,7 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
faiss::IndexComposition index_composition;
index_composition.index = index_.get();
index_composition.quantizer = nullptr;
index_composition.mode = 0; // copy all
index_composition.mode = 0; // copy all
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, &index_composition, &option);
......@@ -99,17 +98,13 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
}
void
IVFSQHybrid::LoadImpl(const BinarySet &index_binary) {
FaissBaseIndex::LoadImpl(index_binary); // load on cpu
IVFSQHybrid::LoadImpl(const BinarySet& index_binary) {
FaissBaseIndex::LoadImpl(index_binary); // load on cpu
}
void
IVFSQHybrid::search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) {
IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
const Config& cfg) {
if (gpu_mode) {
GPUIVF::search_impl(n, data, k, distances, labels, cfg);
} else {
......@@ -119,10 +114,10 @@ IVFSQHybrid::search_impl(int64_t n,
}
QuantizerPtr
IVFSQHybrid::LoadQuantizer(const Config &conf) {
IVFSQHybrid::LoadQuantizer(const Config& conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if(quantizer_conf->mode != 1) {
if (quantizer_conf->mode != 1) {
KNOWHERE_THROW_MSG("mode only support 1 in this func");
}
}
......@@ -136,7 +131,7 @@ IVFSQHybrid::LoadQuantizer(const Config &conf) {
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = nullptr;
index_composition->mode = quantizer_conf->mode; // only 1
index_composition->mode = quantizer_conf->mode; // only 1
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
delete gpu_index;
......@@ -157,10 +152,9 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
KNOWHERE_THROW_MSG("Quantizer type error");
}
faiss::IndexIVF *ivf_index =
dynamic_cast<faiss::IndexIVF *>(index_.get());
faiss::IndexIVF* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
faiss::gpu::GpuIndexFlat *is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat *>(ivf_index->quantizer);
faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
if (is_gpu_flat_index == nullptr) {
delete ivf_index->quantizer;
ivf_index->quantizer = ivf_quantizer->quantizer;
......@@ -169,8 +163,8 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
void
IVFSQHybrid::UnsetQuantizer() {
auto *ivf_index = dynamic_cast<faiss::IndexIVF *>(index_.get());
if(ivf_index == nullptr) {
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
if (ivf_index == nullptr) {
KNOWHERE_THROW_MSG("Index type error");
}
......@@ -178,10 +172,10 @@ IVFSQHybrid::UnsetQuantizer() {
}
void
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) {
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
if (quantizer_conf != nullptr) {
if(quantizer_conf->mode != 2) {
if (quantizer_conf->mode != 2) {
KNOWHERE_THROW_MSG("mode only support 2 in this func");
}
}
......@@ -195,20 +189,20 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) {
option.allInGpu = true;
auto ivf_quantizer = std::dynamic_pointer_cast<FaissIVFQuantizer>(q);
if (ivf_quantizer == nullptr) KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer");
if (ivf_quantizer == nullptr)
KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer");
auto index_composition = new faiss::IndexComposition;
index_composition->index = index_.get();
index_composition->quantizer = ivf_quantizer->quantizer;
index_composition->mode = quantizer_conf->mode; // only 2
index_composition->mode = quantizer_conf->mode; // only 2
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
index_.reset(gpu_index);
gpu_mode = true; // all in gpu
gpu_mode = true; // all in gpu
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
}
}
}
} // namespace knowhere
......@@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "VectorIndex.h"
#include <memory>
#include <vector>
#include "VectorIndex.h"
namespace zilliz {
namespace knowhere {
namespace algo {
......@@ -30,18 +30,30 @@ class NsgIndex;
class NSG : public VectorIndex {
public:
explicit NSG(const int64_t& gpu_num):gpu_(gpu_num){}
explicit NSG(const int64_t& gpu_num) : gpu_(gpu_num) {
}
NSG() = default;
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
void Add(const DatasetPtr &dataset, const Config &config) override;
BinarySet Serialize() override;
void Load(const BinarySet &index_binary) override;
int64_t Count() override;
int64_t Dimension() override;
VectorIndexPtr Clone() override;
void Seal() override;
IndexModelPtr
Train(const DatasetPtr& dataset, const Config& config) override;
DatasetPtr
Search(const DatasetPtr& dataset, const Config& config) override;
void
Add(const DatasetPtr& dataset, const Config& config) override;
BinarySet
Serialize() override;
void
Load(const BinarySet& index_binary) override;
int64_t
Count() override;
int64_t
Dimension() override;
VectorIndexPtr
Clone() override;
void
Seal() override;
private:
std::shared_ptr<algo::NsgIndex> index_;
int64_t gpu_;
......@@ -49,5 +61,4 @@ class NSG : public VectorIndex {
using NSGIndexPtr = std::shared_ptr<NSG>();
}
}
} // namespace knowhere
......@@ -15,11 +15,8 @@
// specific language governing permissions and limitations
// under the License.
#pragma once
namespace zilliz {
namespace knowhere {
namespace definition {
......@@ -27,6 +24,5 @@ namespace definition {
#define META_DIM ("dimension")
#define META_K ("k")
} // definition
} // knowhere
} // zilliz
} // namespace definition
} // namespace knowhere
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100755 更改为 100644
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册