diff --git a/.clang-format b/.clang-format index 819c2d361ed02883dd895ad4a5caa7706d460257..73dc2aeb8ff1d339d0f4c72513e1f421806a2d36 100644 --- a/.clang-format +++ b/.clang-format @@ -18,3 +18,10 @@ BasedOnStyle: Google DerivePointerAlignment: false ColumnLimit: 120 +IndentWidth: 4 +AccessModifierOffset: -3 +AlwaysBreakAfterReturnType: All +AllowShortBlocksOnASingleLine: false +AllowShortFunctionsOnASingleLine: false +AllowShortIfStatementsOnASingleLine: false +AlignTrailingComments: true diff --git a/.clang-tidy b/.clang-tidy index 1f7850efe7c624ae8a22700f247c10126f160d10..1695fa668541bf56dbed074b8860347985116218 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -18,8 +18,10 @@ Checks: 'clang-diagnostic-*,clang-analyzer-*,-clang-analyzer-alpha*,google-*,modernize-*,readability-*' # produce HeaderFilterRegex from cpp/build-support/lint_exclusions.txt with: # echo -n '^('; sed -e 's/*/\.*/g' cpp/build-support/lint_exclusions.txt | tr '\n' '|'; echo ')$' -HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/thirdparty.*|.*src/core/thirdparty.*|.*src/grpc.*|)$' +HeaderFilterRegex: '^(.*cmake-build-debug.*|.*cmake-build-release.*|.*cmake_build.*|.*src/core/thirdparty.*|.*thirdparty.*|.*easylogging++.*|.*SqliteMetaImpl.cpp|.*src/grpc.*|.*src/core.*|.*src/wrapper.*)$' AnalyzeTemporaryDtors: true +ChainedConditionalReturn: 1 +ChainedConditionalAssignment: 1 CheckOptions: - key: google-readability-braces-around-statements.ShortStatementLines value: '1' diff --git a/ci/jenkinsfile/milvus_build.groovy b/ci/jenkinsfile/milvus_build.groovy index 44895ed5c95a43104bc91112b53cd5abb24160c6..61307664006028b1969f04fac82bad41a1df8652 100644 --- a/ci/jenkinsfile/milvus_build.groovy +++ b/ci/jenkinsfile/milvus_build.groovy @@ -9,6 +9,7 @@ container('milvus-build-env') { sh "git config --global user.email \"test@zilliz.com\"" sh "git config --global user.name \"test\"" withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh "./build.sh -l" sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' && export JFROG_USER_NAME='${USERNAME}' && export JFROG_PASSWORD='${PASSWORD}' && ./build.sh -t ${params.BUILD_TYPE} -j -u -c" } } diff --git a/ci/jenkinsfile/milvus_build_no_ut.groovy b/ci/jenkinsfile/milvus_build_no_ut.groovy index a807ec7ad450b1f15ef4a2cfeb0f84bf7e2023c3..f72089e8c3994ed0543f7d3e567d57d54b50d190 100644 --- a/ci/jenkinsfile/milvus_build_no_ut.groovy +++ b/ci/jenkinsfile/milvus_build_no_ut.groovy @@ -9,6 +9,7 @@ container('milvus-build-env') { sh "git config --global user.email \"test@zilliz.com\"" sh "git config --global user.name \"test\"" withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh "./build.sh -l" sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' && export JFROG_USER_NAME='${USERNAME}' && export JFROG_PASSWORD='${PASSWORD}' && ./build.sh -t ${params.BUILD_TYPE} -j" } } diff --git a/ci/main_jenkinsfile b/ci/main_jenkinsfile index 40224fe89467b2a8d590431c989708492d11bfe6..12c6c81cfda6bc5c9d2b2b426dd1f801723b0233 100644 --- a/ci/main_jenkinsfile +++ b/ci/main_jenkinsfile @@ -33,12 +33,30 @@ pipeline { cloud 'build-kubernetes' label 'build' defaultContainer 'jnlp' - containerTemplate { - name 'milvus-build-env' - image 'registry.zilliz.com/milvus/milvus-build-env:v0.12' - ttyEnabled true - command 'cat' - } + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" +""" } } stages { diff --git a/ci/main_jenkinsfile_no_ut b/ci/main_jenkinsfile_no_ut index 9948322c3f2964406e895b4b9db860586743ff92..e7382bd1fd2f324c2b31d35e55b64fb3b78f0578 100644 --- a/ci/main_jenkinsfile_no_ut +++ b/ci/main_jenkinsfile_no_ut @@ -33,12 +33,30 @@ pipeline { cloud 'build-kubernetes' label 'build' defaultContainer 'jnlp' - containerTemplate { - name 'milvus-build-env' - image 'registry.zilliz.com/milvus/milvus-build-env:v0.12' - ttyEnabled true - command 'cat' - } + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" +""" } } stages { diff --git a/ci/nightly_main_jenkinsfile b/ci/nightly_main_jenkinsfile index 28352e0c83a8ecfd7eeba31a38099353624f8332..add9e00fb4b75bd32c0ec74c645c8b94a2e57a01 100644 --- a/ci/nightly_main_jenkinsfile +++ b/ci/nightly_main_jenkinsfile @@ -33,12 +33,30 @@ pipeline { cloud 'build-kubernetes' label 'build' defaultContainer 'jnlp' - containerTemplate { - name 'milvus-build-env' - image 'registry.zilliz.com/milvus/milvus-build-env:v0.12' - ttyEnabled true - command 'cat' - } + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" +""" } } stages { diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 33d34fcd34cbf02a3d583bab74c4556f150f1ac3..ea4298152b09af183b8ac69551a8c621fb130f8d 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -9,21 +9,32 @@ Please mark all change in change log and use the ticket from JIRA. - MS-572 - Milvus crash when get SIGINT - MS-577 - Unittest Query randomly hung - MS-587 - Count get wrong result after adding vectors and index built immediately +- MS-599 - search wrong result when table created with metric_type: IP +- MS-601 - Docker logs error caused by get CPUTemperature error +- MS-622 - Delete vectors should be failed if date range is invalid +- MS-620 - Get table row counts display wrong error code ## Improvement - MS-552 - Add and change the easylogging library - MS-553 - Refine cache code -- MS-557 - Merge Log.h +- MS-555 - Remove old scheduler - MS-556 - Add Job Definition in Scheduler +- MS-557 - Merge Log.h - MS-558 - Refine status code - MS-562 - Add JobMgr and TaskCreator in Scheduler - MS-566 - Refactor cmake -- MS-555 - Remove old scheduler - MS-574 - Milvus configuration refactor - MS-578 - Make sure milvus5.0 don't crack 0.3.1 data - MS-585 - Update namespace in scheduler +- MS-606 - Speed up result reduce +- MS-608 - Update TODO names +- MS-609 - Update task construct function +- MS-611 - Add resources validity check in ResourceMgr +- MS-619 - Add optimizer class in scheduler +- MS-614 - Preload table at startup ## New Feature +- MS-627 - Integrate new index: IVFSQHybrid ## Task - MS-554 - Change license to Apache 2.0 @@ -33,6 +44,9 @@ Please mark all change in change log and use the ticket from JIRA. - MS-575 - Add Clang-format & Clang-tidy & Cpplint - MS-586 - Remove BUILD_FAISS_WITH_MKL option - MS-590 - Refine cmake code to support cpplint +- MS-600 - Reconstruct unittest code +- MS-602 - Remove zilliz namespace +- MS-610 - Change error code base value from hex to decimal # Milvus 0.4.0 (2019-09-12) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cc807d43ed4cfa238283566f8a8fb9bf7f1c12b8..d2092eb018444d3e207b42596d5903bc11ced4dc 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -189,7 +189,7 @@ add_custom_target(lint --exclude_globs ${LINT_EXCLUSIONS_FILE} --source_dir - ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR} ${MILVUS_LINT_QUIET}) # diff --git a/cpp/README.md b/cpp/README.md index f5f77d25f8bcd9edb5bba76b354fd2a0fa9bccd4..6e81b567ffd15b94fb6cd61b3a162723f933703d 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -105,15 +105,31 @@ please reinstall CMake with curl: ``` ##### code format and linting - +Install clang-format and clang-tidy ```shell CentOS 7: $ yum install clang -Ubuntu 16.04 or 18.04: -$ sudo apt-get install clang-format clang-tidy - +Ubuntu 16.04: +$ sudo apt-get install clang-tidy +$ sudo su +$ wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - +$ apt-add-repository "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main" +$ apt-get update +$ apt-get install clang-format-6.0 +Ubuntu 18.04: +$ sudo apt-get install clang-tidy clang-format + +$ rm cmake_build/CMakeCache.txt +``` +Check code style +```shell $ ./build.sh -l ``` +To format the code +```shell +$ cd cmake_build +$ make clang-format +``` ##### Run unit test @@ -122,13 +138,14 @@ $ ./build.sh -u ``` ##### Run code coverage - +Install lcov ```shell CentOS 7: $ yum install lcov Ubuntu 16.04 or 18.04: $ sudo apt-get install lcov - +``` +```shell $ ./build.sh -u -c ``` diff --git a/cpp/build-support/code_style_clion.xml b/cpp/build-support/code_style_clion.xml new file mode 100644 index 0000000000000000000000000000000000000000..f2edeec3b640cb4f7d29cf065185e690ea1e9c11 --- /dev/null +++ b/cpp/build-support/code_style_clion.xml @@ -0,0 +1,38 @@ + + + + + + \ No newline at end of file diff --git a/cpp/build-support/lint_exclusions.txt b/cpp/build-support/lint_exclusions.txt index 27bd780f42c9261f14129cc23836a365c60fdcdd..6ac690f661f56473d28d1bf476ccef9d7a944c81 100644 --- a/cpp/build-support/lint_exclusions.txt +++ b/cpp/build-support/lint_exclusions.txt @@ -1,8 +1,9 @@ *cmake-build-debug* *cmake-build-release* *cmake_build* -*src/thirdparty* *src/core/thirdparty* -*src/grpc* +*thirdparty* *easylogging++* -*SqliteMetaImpl.cpp \ No newline at end of file +*SqliteMetaImpl.cpp +*src/grpc* +*milvus/include* \ No newline at end of file diff --git a/cpp/build.sh b/cpp/build.sh index 4087db45928118ed23f0cc8caab1100c54abe799..648206f9ae62c77b2e98338769904ee17f9516f4 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -99,21 +99,26 @@ if [[ ${RUN_CPPLINT} == "ON" ]]; then # cpplint check make lint if [ $? -ne 0 ]; then - echo "ERROR! cpplint check not pass" + echo "ERROR! cpplint check failed" exit 1 fi + echo "cpplint check passed!" + # clang-format check make check-clang-format if [ $? -ne 0 ]; then echo "ERROR! clang-format check failed" exit 1 fi - # clang-tidy check - make check-clang-tidy - if [ $? -ne 0 ]; then - echo "ERROR! clang-tidy check failed" - exit 1 - fi + echo "clang-format check passed!" + +# # clang-tidy check +# make check-clang-tidy +# if [ $? -ne 0 ]; then +# echo "ERROR! clang-tidy check failed" +# exit 1 +# fi +# echo "clang-tidy check passed!" else # compile and build make -j 4 || exit 1 diff --git a/cpp/conf/server_config.template b/cpp/conf/server_config.template index b49b5855f5c97b3dad15ed9499a8cc318a2c6389..2f2f699e09fcee5b4f585b874d5e0f8a24c03913 100644 --- a/cpp/conf/server_config.template +++ b/cpp/conf/server_config.template @@ -11,25 +11,30 @@ db_config: secondary_path: # path used to store data only, split by semicolon backend_url: sqlite://:@:/ # URI format: dialect://username:password@host:port/database - # Keep 'dialect://:@:/', and replace other texts with real values. + # Keep 'dialect://:@:/', and replace other texts with real values # Replace 'dialect' with 'mysql' or 'sqlite' insert_buffer_size: 4 # GB, maximum insert buffer size allowed + # sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory build_index_gpu: 0 # gpu id used for building index + preload_table: # preload data at startup, '*' means load all tables, empty value means no preload + # you can specify preload tables like this: table1,table2,table3 + metric_config: enable_monitor: false # enable monitoring or not collector: prometheus # prometheus prometheus_config: - port: 8080 # port prometheus used to fetch metrics + port: 8080 # port prometheus uses to fetch metrics cache_config: - cpu_mem_capacity: 16 # GB, CPU memory used for cache - cpu_mem_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered + cpu_cache_capacity: 16 # GB, CPU memory used for cache + cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered cache_insert_data: false # whether to load inserted data into cache engine_config: - blas_threshold: 20 + use_blas_threshold: 20 # if nq < use_blas_threshold, use SSE, faster with fluctuated response times + # if nq >= use_blas_threshold, use OpenBlas, slower with stable response times resource_config: resource_pool: diff --git a/cpp/coverage.sh b/cpp/coverage.sh index 84459fcc84e51d504ff6e3d8314600c6eafcc1e8..8a7e5f52a16a7f3204f40bcf04cde274ac7562be 100755 --- a/cpp/coverage.sh +++ b/cpp/coverage.sh @@ -70,11 +70,11 @@ fi for test in `ls ${DIR_UNITTEST}`; do echo $test case ${test} in - db_test) - # set run args for db_test + test_db) + # set run args for test_db args="mysql://${MYSQL_USER_NAME}:${MYSQL_PASSWORD}@${MYSQL_HOST}:${MYSQL_PORT}/${MYSQL_DB_NAME}" ;; - *_test) + *test_*) args="" ;; esac @@ -104,7 +104,7 @@ ${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \ "src/metrics/MetricBase.h"\ "src/server/Server.cpp"\ "src/server/DBWrapper.cpp"\ - "src/server/grpc_impl/GrpcMilvusServer.cpp"\ + "src/server/grpc_impl/GrpcServer.cpp"\ "src/utils/easylogging++.h"\ "src/utils/easylogging++.cc"\ diff --git a/cpp/src/CMakeLists.txt b/cpp/src/CMakeLists.txt index 6ee3966c4ffdab88020a4fab70e345565ea4e61b..0005edbaf7daec85fae4f0026c0d61c3eb256e3c 100644 --- a/cpp/src/CMakeLists.txt +++ b/cpp/src/CMakeLists.txt @@ -20,29 +20,24 @@ include_directories(${MILVUS_SOURCE_DIR}) include_directories(${MILVUS_ENGINE_SRC}) -add_subdirectory(core) +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include) +include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status) +include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus) +#this statement must put here, since the CORE_INCLUDE_DIRS is defined in code/CMakeList.txt +add_subdirectory(core) set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE) foreach (dir ${CORE_INCLUDE_DIRS}) include_directories(${dir}) endforeach () aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files) - +aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler db_scheduler_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/context db_scheduler_context_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/task db_scheduler_task_files) -set(db_scheduler_files - ${db_scheduler_files} - ${db_scheduler_context_files} - ${db_scheduler_task_files} - ) set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc @@ -51,12 +46,11 @@ set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-status/status.pb.cc ) -aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler scheduler_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/action scheduler_action_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/event scheduler_event_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/job scheduler_job_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/optimizer scheduler_optimizer_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/resource scheduler_resource_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/task scheduler_task_files) set(scheduler_files @@ -64,15 +58,14 @@ set(scheduler_files ${scheduler_action_files} ${scheduler_event_files} ${scheduler_job_files} + ${scheduler_optimizer_files} ${scheduler_resource_files} ${scheduler_task_files} ) aux_source_directory(${MILVUS_ENGINE_SRC}/server server_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/grpc_impl grpc_server_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/utils utils_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper wrapper_files) set(engine_files @@ -82,16 +75,11 @@ set(engine_files ${db_engine_files} ${db_insert_files} ${db_meta_files} - ${db_scheduler_files} ${metrics_files} ${utils_files} ${wrapper_files} ) -include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") -include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-status) -include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus) - set(client_grpc_lib grpcpp_channelz grpc++ @@ -112,6 +100,12 @@ set(boost_lib boost_serialization_static ) +set(cuda_lib + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so + cudart + cublas + ) + set(third_party_libs sqlite ${client_grpc_lib} @@ -123,17 +117,15 @@ set(third_party_libs snappy zlib zstd - cudart - cublas + ${cuda_lib} mysqlpp - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so - cudart ) + if (MILVUS_ENABLE_PROFILING STREQUAL "ON") set(third_party_libs ${third_party_libs} - gperftools - libunwind - ) + gperftools + libunwind + ) endif () link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") @@ -141,7 +133,6 @@ set(engine_libs pthread libgomp.a libgfortran.a - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so ) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") @@ -152,7 +143,11 @@ if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") endif () cuda_add_library(milvus_engine STATIC ${engine_files}) -target_link_libraries(milvus_engine ${engine_libs} knowhere ${third_party_libs}) +target_link_libraries(milvus_engine + knowhere + ${engine_libs} + ${third_party_libs} + ) add_library(metrics STATIC ${metrics_files}) @@ -180,7 +175,9 @@ add_executable(milvus_server ${utils_files} ) -target_link_libraries(milvus_server ${server_libs}) +target_link_libraries(milvus_server + ${server_libs} + ) install(TARGETS milvus_server DESTINATION bin) diff --git a/cpp/src/cache/Cache.h b/cpp/src/cache/Cache.h index b526e9339f7197ddc7f75d1c0384c48b23878c6c..62a7f13ca854f8b55664a1b72079d3aea5f057b9 100644 --- a/cpp/src/cache/Cache.h +++ b/cpp/src/cache/Cache.h @@ -15,55 +15,66 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "LRU.h" #include "utils/Log.h" -#include -#include #include +#include #include +#include -namespace zilliz { namespace milvus { namespace cache { -template +template class Cache { public: - //mem_capacity, units:GB + // mem_capacity, units:GB Cache(int64_t capacity_gb, uint64_t cache_max_count); ~Cache() = default; - int64_t usage() const { + int64_t + usage() const { return usage_; } - int64_t capacity() const { + int64_t + capacity() const { return capacity_; - } //unit: BYTE - void set_capacity(int64_t capacity); //unit: BYTE + } // unit: BYTE + void + set_capacity(int64_t capacity); // unit: BYTE - double freemem_percent() const { + double + freemem_percent() const { return freemem_percent_; } - void set_freemem_percent(double percent) { + void + set_freemem_percent(double percent) { freemem_percent_ = percent; } - size_t size() const; - bool exists(const std::string &key); - ItemObj get(const std::string &key); - void insert(const std::string &key, const ItemObj &item); - void erase(const std::string &key); - void print(); - void clear(); + size_t + size() const; + bool + exists(const std::string& key); + ItemObj + get(const std::string& key); + void + insert(const std::string& key, const ItemObj& item); + void + erase(const std::string& key); + void + print(); + void + clear(); private: - void free_memory(); + void + free_memory(); private: int64_t usage_; @@ -74,8 +85,7 @@ class Cache { mutable std::mutex mutex_; }; -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus #include "cache/Cache.inl" diff --git a/cpp/src/cache/Cache.inl b/cpp/src/cache/Cache.inl index ebd8f3c25d2b23fd5adce6e3d50e36efb26a7e0e..2ab6b3ffa66480612e92056d62eefbe03eb809c7 100644 --- a/cpp/src/cache/Cache.inl +++ b/cpp/src/cache/Cache.inl @@ -17,7 +17,7 @@ -namespace zilliz { + namespace milvus { namespace cache { @@ -190,5 +190,5 @@ Cache::print() { } // namespace cache } // namespace milvus -} // namespace zilliz + diff --git a/cpp/src/cache/CacheMgr.h b/cpp/src/cache/CacheMgr.h index 53004d10b23b41a90a7090a114f2ad730e05bff9..f5e04f650968552c01664139b8442b56fddf3868 100644 --- a/cpp/src/cache/CacheMgr.h +++ b/cpp/src/cache/CacheMgr.h @@ -15,40 +15,48 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "Cache.h" -#include "utils/Log.h" #include "metrics/Metrics.h" +#include "utils/Log.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace cache { -template +template class CacheMgr { public: - virtual uint64_t ItemCount() const; + virtual uint64_t + ItemCount() const; - virtual bool ItemExists(const std::string &key); + virtual bool + ItemExists(const std::string& key); - virtual ItemObj GetItem(const std::string &key); + virtual ItemObj + GetItem(const std::string& key); - virtual void InsertItem(const std::string &key, const ItemObj &data); + virtual void + InsertItem(const std::string& key, const ItemObj& data); - virtual void EraseItem(const std::string &key); + virtual void + EraseItem(const std::string& key); - virtual void PrintInfo(); + virtual void + PrintInfo(); - virtual void ClearCache(); + virtual void + ClearCache(); - int64_t CacheUsage() const; - int64_t CacheCapacity() const; - void SetCapacity(int64_t capacity); + int64_t + CacheUsage() const; + int64_t + CacheCapacity() const; + void + SetCapacity(int64_t capacity); protected: CacheMgr(); @@ -59,8 +67,7 @@ class CacheMgr { CachePtr cache_; }; -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus #include "cache/CacheMgr.inl" diff --git a/cpp/src/cache/CacheMgr.inl b/cpp/src/cache/CacheMgr.inl index b0c47b3dc37f54b6b336fe26f3b292d45ba5697f..23b2f0df743b7b237887fec8d322b413fd34f34f 100644 --- a/cpp/src/cache/CacheMgr.inl +++ b/cpp/src/cache/CacheMgr.inl @@ -16,7 +16,7 @@ // under the License. -namespace zilliz { + namespace milvus { namespace cache { @@ -142,4 +142,4 @@ CacheMgr::SetCapacity(int64_t capacity) { } // namespace cache } // namespace milvus -} // namespace zilliz + diff --git a/cpp/src/cache/CpuCacheMgr.cpp b/cpp/src/cache/CpuCacheMgr.cpp index d26004b2bef5a843700a288075f825022f5db565..7cfe59e72e58b3bcd91b75a00df87e8f4e030878 100644 --- a/cpp/src/cache/CpuCacheMgr.cpp +++ b/cpp/src/cache/CpuCacheMgr.cpp @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. - #include "cache/CpuCacheMgr.h" #include "server/Config.h" #include "utils/Log.h" #include -namespace zilliz { namespace milvus { namespace cache { @@ -31,38 +29,38 @@ constexpr int64_t unit = 1024 * 1024 * 1024; } CpuCacheMgr::CpuCacheMgr() { - server::Config &config = server::Config::GetInstance(); + server::Config& config = server::Config::GetInstance(); Status s; - int32_t cpu_mem_cap; - s = config.GetCacheConfigCpuMemCapacity(cpu_mem_cap); + int32_t cpu_cache_cap; + s = config.GetCacheConfigCpuCacheCapacity(cpu_cache_cap); if (!s.ok()) { SERVER_LOG_ERROR << s.message(); } - int64_t cap = cpu_mem_cap * unit; + int64_t cap = cpu_cache_cap * unit; cache_ = std::make_shared>(cap, 1UL << 32); - float cpu_mem_threshold; - s = config.GetCacheConfigCpuMemThreshold(cpu_mem_threshold); + float cpu_cache_threshold; + s = config.GetCacheConfigCpuCacheThreshold(cpu_cache_threshold); if (!s.ok()) { SERVER_LOG_ERROR << s.message(); } - if (cpu_mem_threshold > 0.0 && cpu_mem_threshold <= 1.0) { - cache_->set_freemem_percent(cpu_mem_threshold); + if (cpu_cache_threshold > 0.0 && cpu_cache_threshold <= 1.0) { + cache_->set_freemem_percent(cpu_cache_threshold); } else { - SERVER_LOG_ERROR << "Invalid cpu_mem_threshold: " << cpu_mem_threshold - << ", by default set to " << cache_->freemem_percent(); + SERVER_LOG_ERROR << "Invalid cpu_cache_threshold: " << cpu_cache_threshold << ", by default set to " + << cache_->freemem_percent(); } } -CpuCacheMgr * +CpuCacheMgr* CpuCacheMgr::GetInstance() { static CpuCacheMgr s_mgr; return &s_mgr; } engine::VecIndexPtr -CpuCacheMgr::GetIndex(const std::string &key) { +CpuCacheMgr::GetIndex(const std::string& key) { DataObjPtr obj = GetItem(key); if (obj != nullptr) { return obj->data(); @@ -71,6 +69,5 @@ CpuCacheMgr::GetIndex(const std::string &key) { return nullptr; } -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus diff --git a/cpp/src/cache/CpuCacheMgr.h b/cpp/src/cache/CpuCacheMgr.h index 32a83c4cc0dd6a510f3d9386e410db5651951d8b..2f8c50991a9e50eadf60935bbcc40d3d78de3d4d 100644 --- a/cpp/src/cache/CpuCacheMgr.h +++ b/cpp/src/cache/CpuCacheMgr.h @@ -20,10 +20,9 @@ #include "CacheMgr.h" #include "DataObj.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace cache { @@ -32,12 +31,13 @@ class CpuCacheMgr : public CacheMgr { CpuCacheMgr(); public: - //TODO: use smart pointer instead - static CpuCacheMgr *GetInstance(); + // TODO(myh): use smart pointer instead + static CpuCacheMgr* + GetInstance(); - engine::VecIndexPtr GetIndex(const std::string &key); + engine::VecIndexPtr + GetIndex(const std::string& key); }; -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus diff --git a/cpp/src/cache/DataObj.h b/cpp/src/cache/DataObj.h index 6ed3256eee1bb22f7a9535af666deceab4509a9b..eb58b708250cbc221361f2b782ac8042a9a8d8a4 100644 --- a/cpp/src/cache/DataObj.h +++ b/cpp/src/cache/DataObj.h @@ -15,37 +15,35 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "src/wrapper/VecIndex.h" #include -namespace zilliz { namespace milvus { namespace cache { class DataObj { public: - explicit DataObj(const engine::VecIndexPtr &index) - : index_(index) { + explicit DataObj(const engine::VecIndexPtr& index) : index_(index) { } - DataObj(const engine::VecIndexPtr &index, int64_t size) - : index_(index), - size_(size) { + DataObj(const engine::VecIndexPtr& index, int64_t size) : index_(index), size_(size) { } - engine::VecIndexPtr data() { + engine::VecIndexPtr + data() { return index_; } - const engine::VecIndexPtr &data() const { + const engine::VecIndexPtr& + data() const { return index_; } - int64_t size() const { + int64_t + size() const { if (index_ == nullptr) { return 0; } @@ -64,6 +62,5 @@ class DataObj { using DataObjPtr = std::shared_ptr; -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus diff --git a/cpp/src/cache/GpuCacheMgr.cpp b/cpp/src/cache/GpuCacheMgr.cpp index ad68a6ebef47ee903e954796abbb05affd4c3192..a9eff6d3c3ceb295675ff3ce35d6be000dceb308 100644 --- a/cpp/src/cache/GpuCacheMgr.cpp +++ b/cpp/src/cache/GpuCacheMgr.cpp @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. - #include "cache/GpuCacheMgr.h" -#include "utils/Log.h" #include "server/Config.h" +#include "utils/Log.h" #include #include -namespace zilliz { namespace milvus { namespace cache { @@ -35,31 +33,31 @@ constexpr int64_t G_BYTE = 1024 * 1024 * 1024; } GpuCacheMgr::GpuCacheMgr() { - server::Config &config = server::Config::GetInstance(); + server::Config& config = server::Config::GetInstance(); Status s; - int32_t gpu_mem_cap; - s = config.GetCacheConfigGpuMemCapacity(gpu_mem_cap); + int32_t gpu_cache_cap; + s = config.GetCacheConfigGpuCacheCapacity(gpu_cache_cap); if (!s.ok()) { SERVER_LOG_ERROR << s.message(); } - int32_t cap = gpu_mem_cap * G_BYTE; + int32_t cap = gpu_cache_cap * G_BYTE; cache_ = std::make_shared>(cap, 1UL << 32); float gpu_mem_threshold; - s = config.GetCacheConfigGpuMemThreshold(gpu_mem_threshold); + s = config.GetCacheConfigGpuCacheThreshold(gpu_mem_threshold); if (!s.ok()) { SERVER_LOG_ERROR << s.message(); } if (gpu_mem_threshold > 0.0 && gpu_mem_threshold <= 1.0) { cache_->set_freemem_percent(gpu_mem_threshold); } else { - SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold - << ", by default set to " << cache_->freemem_percent(); + SERVER_LOG_ERROR << "Invalid gpu_mem_threshold: " << gpu_mem_threshold << ", by default set to " + << cache_->freemem_percent(); } } -GpuCacheMgr * +GpuCacheMgr* GpuCacheMgr::GetInstance(uint64_t gpu_id) { if (instance_.find(gpu_id) == instance_.end()) { std::lock_guard lock(mutex_); @@ -74,7 +72,7 @@ GpuCacheMgr::GetInstance(uint64_t gpu_id) { } engine::VecIndexPtr -GpuCacheMgr::GetIndex(const std::string &key) { +GpuCacheMgr::GetIndex(const std::string& key) { DataObjPtr obj = GetItem(key); if (obj != nullptr) { return obj->data(); @@ -83,6 +81,5 @@ GpuCacheMgr::GetIndex(const std::string &key) { return nullptr; } -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus diff --git a/cpp/src/cache/GpuCacheMgr.h b/cpp/src/cache/GpuCacheMgr.h index 843a5ff67d7a5479c9c54e677345701ae7125fee..8c0d4b95026ab4aea4684b55d593b2d8b3c20b4c 100644 --- a/cpp/src/cache/GpuCacheMgr.h +++ b/cpp/src/cache/GpuCacheMgr.h @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. - #include "CacheMgr.h" #include "DataObj.h" -#include #include #include +#include -namespace zilliz { namespace milvus { namespace cache { @@ -34,15 +32,16 @@ class GpuCacheMgr : public CacheMgr { public: GpuCacheMgr(); - static GpuCacheMgr *GetInstance(uint64_t gpu_id); + static GpuCacheMgr* + GetInstance(uint64_t gpu_id); - engine::VecIndexPtr GetIndex(const std::string &key); + engine::VecIndexPtr + GetIndex(const std::string& key); private: static std::mutex mutex_; static std::unordered_map instance_; }; -} // namespace cache -} // namespace milvus -} // namespace zilliz +} // namespace cache +} // namespace milvus diff --git a/cpp/src/cache/LRU.h b/cpp/src/cache/LRU.h index 5446dd0f14ec909c761f05f125fcc13d3d4bfcaa..9b0eac060854e5d24871c8217308224b71b39b72 100644 --- a/cpp/src/cache/LRU.h +++ b/cpp/src/cache/LRU.h @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include -#include #include +#include #include +#include #include -namespace zilliz { namespace milvus { namespace cache { -template +template class LRU { public: typedef typename std::pair key_value_pair_t; @@ -38,7 +36,8 @@ class LRU { explicit LRU(size_t max_size) : max_size_(max_size) { } - void put(const key_t &key, const value_t &value) { + void + put(const key_t& key, const value_t& value) { auto it = cache_items_map_.find(key); cache_items_list_.push_front(key_value_pair_t(key, value)); if (it != cache_items_map_.end()) { @@ -55,7 +54,8 @@ class LRU { } } - const value_t &get(const key_t &key) { + const value_t& + get(const key_t& key) { auto it = cache_items_map_.find(key); if (it == cache_items_map_.end()) { throw std::range_error("There is no such key in cache"); @@ -65,7 +65,8 @@ class LRU { } } - void erase(const key_t &key) { + void + erase(const key_t& key) { auto it = cache_items_map_.find(key); if (it != cache_items_map_.end()) { cache_items_list_.erase(it->second); @@ -73,32 +74,39 @@ class LRU { } } - bool exists(const key_t &key) const { + bool + exists(const key_t& key) const { return cache_items_map_.find(key) != cache_items_map_.end(); } - size_t size() const { + size_t + size() const { return cache_items_map_.size(); } - list_iterator_t begin() { + list_iterator_t + begin() { iter_ = cache_items_list_.begin(); return iter_; } - list_iterator_t end() { + list_iterator_t + end() { return cache_items_list_.end(); } - reverse_list_iterator_t rbegin() { + reverse_list_iterator_t + rbegin() { return cache_items_list_.rbegin(); } - reverse_list_iterator_t rend() { + reverse_list_iterator_t + rend() { return cache_items_list_.rend(); } - void clear() { + void + clear() { cache_items_list_.clear(); cache_items_map_.clear(); } @@ -110,7 +118,5 @@ class LRU { list_iterator_t iter_; }; -} // namespace cache -} // namespace milvus -} // namespace zilliz - +} // namespace cache +} // namespace milvus diff --git a/cpp/src/config/ConfigMgr.cpp b/cpp/src/config/ConfigMgr.cpp index a7cfdc9fdaa9f796e39c2d0412a472b8b812d4f0..5889f52fac2987ab7a8c7830c176acdf2d0417c6 100644 --- a/cpp/src/config/ConfigMgr.cpp +++ b/cpp/src/config/ConfigMgr.cpp @@ -18,16 +18,14 @@ #include "config/ConfigMgr.h" #include "YamlConfigMgr.h" -namespace zilliz { namespace milvus { namespace server { -ConfigMgr * +ConfigMgr* ConfigMgr::GetInstance() { static YamlConfigMgr mgr; return &mgr; } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/config/ConfigMgr.h b/cpp/src/config/ConfigMgr.h index 96a63aa2c8e1a784bf828161a7a5447cba08b280..40753beaa9766e536fb2c34cecfd7166022add49 100644 --- a/cpp/src/config/ConfigMgr.h +++ b/cpp/src/config/ConfigMgr.h @@ -17,12 +17,11 @@ #pragma once -#include "utils/Error.h" #include "ConfigNode.h" +#include "utils/Error.h" #include -namespace zilliz { namespace milvus { namespace server { @@ -42,16 +41,21 @@ namespace server { class ConfigMgr { public: - static ConfigMgr *GetInstance(); - - virtual ErrorCode LoadConfigFile(const std::string &filename) = 0; - virtual void Print() const = 0;//will be deleted - virtual std::string DumpString() const = 0; - - virtual const ConfigNode &GetRootNode() const = 0; - virtual ConfigNode &GetRootNode() = 0; + static ConfigMgr* + GetInstance(); + + virtual ErrorCode + LoadConfigFile(const std::string& filename) = 0; + virtual void + Print() const = 0; // will be deleted + virtual std::string + DumpString() const = 0; + + virtual const ConfigNode& + GetRootNode() const = 0; + virtual ConfigNode& + GetRootNode() = 0; }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/config/ConfigNode.cpp b/cpp/src/config/ConfigNode.cpp index 87b080d5722e9276dc48359faa4862fbd771d795..cf148e4d29e57341276d1e1a95df99f320046995 100644 --- a/cpp/src/config/ConfigNode.cpp +++ b/cpp/src/config/ConfigNode.cpp @@ -19,51 +19,50 @@ #include "utils/Error.h" #include "utils/Log.h" +#include #include #include -#include -namespace zilliz { namespace milvus { namespace server { void -ConfigNode::Combine(const ConfigNode &target) { - const std::map &kv = target.GetConfig(); +ConfigNode::Combine(const ConfigNode& target) { + const std::map& kv = target.GetConfig(); for (auto itr = kv.begin(); itr != kv.end(); ++itr) { config_[itr->first] = itr->second; } - const std::map > &sequences = target.GetSequences(); + const std::map >& sequences = target.GetSequences(); for (auto itr = sequences.begin(); itr != sequences.end(); ++itr) { sequences_[itr->first] = itr->second; } - const std::map &children = target.GetChildren(); + const std::map& children = target.GetChildren(); for (auto itr = children.begin(); itr != children.end(); ++itr) { children_[itr->first] = itr->second; } } -//key/value pair config +// key/value pair config void -ConfigNode::SetValue(const std::string &key, const std::string &value) { +ConfigNode::SetValue(const std::string& key, const std::string& value) { config_[key] = value; } std::string -ConfigNode::GetValue(const std::string ¶m_key, const std::string &default_val) const { +ConfigNode::GetValue(const std::string& param_key, const std::string& default_val) const { auto ref = config_.find(param_key); if (ref != config_.end()) { return ref->second; } - //THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key); + // THROW_UNEXPECTED_ERROR("Can't find parameter key: " + param_key); return default_val; } bool -ConfigNode::GetBoolValue(const std::string ¶m_key, bool default_val) const { +ConfigNode::GetBoolValue(const std::string& param_key, bool default_val) const { std::string val = GetValue(param_key); if (!val.empty()) { std::transform(val.begin(), val.end(), val.begin(), ::tolower); @@ -74,17 +73,17 @@ ConfigNode::GetBoolValue(const std::string ¶m_key, bool default_val) const { } int32_t -ConfigNode::GetInt32Value(const std::string ¶m_key, int32_t default_val) const { +ConfigNode::GetInt32Value(const std::string& param_key, int32_t default_val) const { std::string val = GetValue(param_key); if (!val.empty()) { - return (int32_t) std::strtol(val.c_str(), nullptr, 10); + return (int32_t)std::strtol(val.c_str(), nullptr, 10); } else { return default_val; } } int64_t -ConfigNode::GetInt64Value(const std::string ¶m_key, int64_t default_val) const { +ConfigNode::GetInt64Value(const std::string& param_key, int64_t default_val) const { std::string val = GetValue(param_key); if (!val.empty()) { return std::strtol(val.c_str(), nullptr, 10); @@ -94,7 +93,7 @@ ConfigNode::GetInt64Value(const std::string ¶m_key, int64_t default_val) con } float -ConfigNode::GetFloatValue(const std::string ¶m_key, float default_val) const { +ConfigNode::GetFloatValue(const std::string& param_key, float default_val) const { std::string val = GetValue(param_key); if (!val.empty()) { return std::strtof(val.c_str(), nullptr); @@ -104,7 +103,7 @@ ConfigNode::GetFloatValue(const std::string ¶m_key, float default_val) const } double -ConfigNode::GetDoubleValue(const std::string ¶m_key, double default_val) const { +ConfigNode::GetDoubleValue(const std::string& param_key, double default_val) const { std::string val = GetValue(param_key); if (!val.empty()) { return std::strtod(val.c_str(), nullptr); @@ -113,7 +112,7 @@ ConfigNode::GetDoubleValue(const std::string ¶m_key, double default_val) con } } -const std::map & +const std::map& ConfigNode::GetConfig() const { return config_; } @@ -123,14 +122,14 @@ ConfigNode::ClearConfig() { config_.clear(); } -//key/object config +// key/object config void -ConfigNode::AddChild(const std::string &type_name, const ConfigNode &config) { +ConfigNode::AddChild(const std::string& type_name, const ConfigNode& config) { children_[type_name] = config; } ConfigNode -ConfigNode::GetChild(const std::string &type_name) const { +ConfigNode::GetChild(const std::string& type_name) const { auto ref = children_.find(type_name); if (ref != children_.end()) { return ref->second; @@ -140,20 +139,20 @@ ConfigNode::GetChild(const std::string &type_name) const { return nc; } -ConfigNode & -ConfigNode::GetChild(const std::string &type_name) { +ConfigNode& +ConfigNode::GetChild(const std::string& type_name) { return children_[type_name]; } void -ConfigNode::GetChildren(ConfigNodeArr &arr) const { +ConfigNode::GetChildren(ConfigNodeArr& arr) const { arr.clear(); for (auto ref : children_) { arr.push_back(ref.second); } } -const std::map & +const std::map& ConfigNode::GetChildren() const { return children_; } @@ -163,14 +162,14 @@ ConfigNode::ClearChildren() { children_.clear(); } -//key/sequence config +// key/sequence config void -ConfigNode::AddSequenceItem(const std::string &key, const std::string &item) { +ConfigNode::AddSequenceItem(const std::string& key, const std::string& item) { sequences_[key].push_back(item); } std::vector -ConfigNode::GetSequence(const std::string &key) const { +ConfigNode::GetSequence(const std::string& key) const { auto itr = sequences_.find(key); if (itr != sequences_.end()) { return itr->second; @@ -180,7 +179,7 @@ ConfigNode::GetSequence(const std::string &key) const { } } -const std::map > & +const std::map >& ConfigNode::GetSequences() const { return sequences_; } @@ -191,40 +190,40 @@ ConfigNode::ClearSequences() { } void -ConfigNode::PrintAll(const std::string &prefix) const { - for (auto &elem : config_) { +ConfigNode::PrintAll(const std::string& prefix) const { + for (auto& elem : config_) { SERVER_LOG_INFO << prefix << elem.first + ": " << elem.second; } - for (auto &elem : sequences_) { + for (auto& elem : sequences_) { SERVER_LOG_INFO << prefix << elem.first << ": "; - for (auto &str : elem.second) { + for (auto& str : elem.second) { SERVER_LOG_INFO << prefix << " - " << str; } } - for (auto &elem : children_) { + for (auto& elem : children_) { SERVER_LOG_INFO << prefix << elem.first << ": "; elem.second.PrintAll(prefix + " "); } } std::string -ConfigNode::DumpString(const std::string &prefix) const { +ConfigNode::DumpString(const std::string& prefix) const { std::stringstream str_buffer; const std::string endl = "\n"; - for (auto &elem : config_) { + for (auto& elem : config_) { str_buffer << prefix << elem.first << ": " << elem.second << endl; } - for (auto &elem : sequences_) { + for (auto& elem : sequences_) { str_buffer << prefix << elem.first << ": " << endl; - for (auto &str : elem.second) { + for (auto& str : elem.second) { str_buffer << prefix + " - " << str << endl; } } - for (auto &elem : children_) { + for (auto& elem : children_) { str_buffer << prefix << elem.first << ": " << endl; str_buffer << elem.second.DumpString(prefix + " ") << endl; } @@ -232,6 +231,5 @@ ConfigNode::DumpString(const std::string &prefix) const { return str_buffer.str(); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/config/ConfigNode.h b/cpp/src/config/ConfigNode.h index d3fabc06550295dfeed2afdd582c1f46d9ef1621..b7bf2c3af14607246e8ffab10afd47ca2ba32c4a 100644 --- a/cpp/src/config/ConfigNode.h +++ b/cpp/src/config/ConfigNode.h @@ -17,11 +17,10 @@ #pragma once -#include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace server { @@ -30,39 +29,61 @@ typedef std::vector ConfigNodeArr; class ConfigNode { public: - void Combine(const ConfigNode &target); + void + Combine(const ConfigNode& target); - //key/value pair config - void SetValue(const std::string &key, const std::string &value); + // key/value pair config + void + SetValue(const std::string& key, const std::string& value); - std::string GetValue(const std::string ¶m_key, const std::string &default_val = "") const; - bool GetBoolValue(const std::string ¶m_key, bool default_val = false) const; - int32_t GetInt32Value(const std::string ¶m_key, int32_t default_val = 0) const; - int64_t GetInt64Value(const std::string ¶m_key, int64_t default_val = 0) const; - float GetFloatValue(const std::string ¶m_key, float default_val = 0.0) const; - double GetDoubleValue(const std::string ¶m_key, double default_val = 0.0) const; + std::string + GetValue(const std::string& param_key, const std::string& default_val = "") const; + bool + GetBoolValue(const std::string& param_key, bool default_val = false) const; + int32_t + GetInt32Value(const std::string& param_key, int32_t default_val = 0) const; + int64_t + GetInt64Value(const std::string& param_key, int64_t default_val = 0) const; + float + GetFloatValue(const std::string& param_key, float default_val = 0.0) const; + double + GetDoubleValue(const std::string& param_key, double default_val = 0.0) const; - const std::map &GetConfig() const; - void ClearConfig(); + const std::map& + GetConfig() const; + void + ClearConfig(); - //key/object config - void AddChild(const std::string &type_name, const ConfigNode &config); - ConfigNode GetChild(const std::string &type_name) const; - ConfigNode &GetChild(const std::string &type_name); - void GetChildren(ConfigNodeArr &arr) const; + // key/object config + void + AddChild(const std::string& type_name, const ConfigNode& config); + ConfigNode + GetChild(const std::string& type_name) const; + ConfigNode& + GetChild(const std::string& type_name); + void + GetChildren(ConfigNodeArr& arr) const; - const std::map &GetChildren() const; - void ClearChildren(); + const std::map& + GetChildren() const; + void + ClearChildren(); - //key/sequence config - void AddSequenceItem(const std::string &key, const std::string &item); - std::vector GetSequence(const std::string &key) const; + // key/sequence config + void + AddSequenceItem(const std::string& key, const std::string& item); + std::vector + GetSequence(const std::string& key) const; - const std::map > &GetSequences() const; - void ClearSequences(); + const std::map >& + GetSequences() const; + void + ClearSequences(); - void PrintAll(const std::string &prefix = "") const; - std::string DumpString(const std::string &prefix = "") const; + void + PrintAll(const std::string& prefix = "") const; + std::string + DumpString(const std::string& prefix = "") const; private: std::map config_; @@ -70,6 +91,5 @@ class ConfigNode { std::map > sequences_; }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/config/YamlConfigMgr.cpp b/cpp/src/config/YamlConfigMgr.cpp index 1dd8a3d9435554e692f8ea0aeb2da341b43e5b70..71535495442851e1e3ce276cb916428064fb8458 100644 --- a/cpp/src/config/YamlConfigMgr.cpp +++ b/cpp/src/config/YamlConfigMgr.cpp @@ -20,12 +20,11 @@ #include -namespace zilliz { namespace milvus { namespace server { ErrorCode -YamlConfigMgr::LoadConfigFile(const std::string &filename) { +YamlConfigMgr::LoadConfigFile(const std::string& filename) { struct stat directoryStat; int statOK = stat(filename.c_str(), &directoryStat); if (statOK != 0) { @@ -36,8 +35,7 @@ YamlConfigMgr::LoadConfigFile(const std::string &filename) { try { node_ = YAML::LoadFile(filename); LoadConfigNode(node_, config_); - } - catch (YAML::Exception &e) { + } catch (YAML::Exception& e) { SERVER_LOG_ERROR << "Failed to load config file: " << std::string(e.what()); return SERVER_UNEXPECTED_ERROR; } @@ -56,20 +54,18 @@ YamlConfigMgr::DumpString() const { return config_.DumpString(""); } -const ConfigNode & +const ConfigNode& YamlConfigMgr::GetRootNode() const { return config_; } -ConfigNode & +ConfigNode& YamlConfigMgr::GetRootNode() { return config_; } bool -YamlConfigMgr::SetConfigValue(const YAML::Node &node, - const std::string &key, - ConfigNode &config) { +YamlConfigMgr::SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config) { if (node[key].IsDefined()) { config.SetValue(key, node[key].as()); return true; @@ -78,9 +74,7 @@ YamlConfigMgr::SetConfigValue(const YAML::Node &node, } bool -YamlConfigMgr::SetChildConfig(const YAML::Node &node, - const std::string &child_name, - ConfigNode &config) { +YamlConfigMgr::SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config) { if (node[child_name].IsDefined()) { ConfigNode sub_config; LoadConfigNode(node[child_name], sub_config); @@ -91,9 +85,7 @@ YamlConfigMgr::SetChildConfig(const YAML::Node &node, } bool -YamlConfigMgr::SetSequence(const YAML::Node &node, - const std::string &child_name, - ConfigNode &config) { +YamlConfigMgr::SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config) { if (node[child_name].IsDefined()) { size_t cnt = node[child_name].size(); for (size_t i = 0; i < cnt; i++) { @@ -105,7 +97,7 @@ YamlConfigMgr::SetSequence(const YAML::Node &node, } void -YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) { +YamlConfigMgr::LoadConfigNode(const YAML::Node& node, ConfigNode& config) { std::string key; for (YAML::const_iterator it = node.begin(); it != node.end(); ++it) { if (!it->first.IsNull()) { @@ -121,6 +113,5 @@ YamlConfigMgr::LoadConfigNode(const YAML::Node &node, ConfigNode &config) { } } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/config/YamlConfigMgr.h b/cpp/src/config/YamlConfigMgr.h index 6b84943bf808d81e4de875eea5fb5af92e625cfc..1c68bc883fd814ccd491ee3b3ed0f89654c1185f 100644 --- a/cpp/src/config/YamlConfigMgr.h +++ b/cpp/src/config/YamlConfigMgr.h @@ -21,43 +21,43 @@ #include "ConfigNode.h" #include "utils/Error.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace server { class YamlConfigMgr : public ConfigMgr { public: - virtual ErrorCode LoadConfigFile(const std::string &filename); - virtual void Print() const; - virtual std::string DumpString() const; - - virtual const ConfigNode &GetRootNode() const; - virtual ConfigNode &GetRootNode(); + virtual ErrorCode + LoadConfigFile(const std::string& filename); + virtual void + Print() const; + virtual std::string + DumpString() const; + + virtual const ConfigNode& + GetRootNode() const; + virtual ConfigNode& + GetRootNode(); private: - bool SetConfigValue(const YAML::Node &node, - const std::string &key, - ConfigNode &config); + bool + SetConfigValue(const YAML::Node& node, const std::string& key, ConfigNode& config); - bool SetChildConfig(const YAML::Node &node, - const std::string &name, - ConfigNode &config); + bool + SetChildConfig(const YAML::Node& node, const std::string& child_name, ConfigNode& config); bool - SetSequence(const YAML::Node &node, - const std::string &child_name, - ConfigNode &config); + SetSequence(const YAML::Node& node, const std::string& child_name, ConfigNode& config); - void LoadConfigNode(const YAML::Node &node, ConfigNode &config); + void + LoadConfigNode(const YAML::Node& node, ConfigNode& config); private: YAML::Node node_; ConfigNode config_; }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/core/CMakeLists.txt b/cpp/src/core/CMakeLists.txt index bef3f5e2f685be12023f87fa3b80d07121364e59..9125f6ea2c5fc2bdedd57533eaeaab341d047561 100644 --- a/cpp/src/core/CMakeLists.txt +++ b/cpp/src/core/CMakeLists.txt @@ -46,9 +46,11 @@ if(NOT CMAKE_BUILD_TYPE) endif(NOT CMAKE_BUILD_TYPE) if(CMAKE_BUILD_TYPE STREQUAL "Release") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -fopenmp") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3") else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -fopenmp") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g") endif() MESSAGE(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS}) @@ -93,7 +95,7 @@ endif() set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE) if(BUILD_UNIT_TEST STREQUAL "ON") - add_subdirectory(test) + add_subdirectory(unittest) endif() config_summary() diff --git a/cpp/src/core/build.sh b/cpp/src/core/build.sh old mode 100755 new mode 100644 diff --git a/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.cpp b/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.cpp index c07cb16fca566630b56b06294cbe86e79b270543..38227f43ab858ed652ab2b377102121304fdfc0e 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.cpp +++ b/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.cpp @@ -15,42 +15,39 @@ // specific language governing permissions and limitations // under the License. +#include "knowhere/adapter/ArrowAdapter.h" -#include "ArrowAdapter.h" - -namespace zilliz { namespace knowhere { ArrayPtr -CopyArray(const ArrayPtr &origin) { +CopyArray(const ArrayPtr& origin) { ArrayPtr copy = nullptr; auto copy_data = origin->data()->Copy(); switch (origin->type_id()) { -#define DEFINE_TYPE(type, clazz) \ - case arrow::Type::type: { \ - copy = std::make_shared(copy_data); \ - } +#define DEFINE_TYPE(type, clazz) \ + case arrow::Type::type: { \ + copy = std::make_shared(copy_data); \ + } DEFINE_TYPE(BOOL, BooleanArray) DEFINE_TYPE(BINARY, BinaryArray) DEFINE_TYPE(FIXED_SIZE_BINARY, FixedSizeBinaryArray) DEFINE_TYPE(DECIMAL, Decimal128Array) DEFINE_TYPE(FLOAT, NumericArray) DEFINE_TYPE(INT64, NumericArray) - default:break; + default: + break; } return copy; } SchemaPtr -CopySchema(const SchemaPtr &origin) { +CopySchema(const SchemaPtr& origin) { std::vector> fields; - for (auto &field : origin->fields()) { - auto copy = std::make_shared(field->name(), field->type(),field->nullable(), nullptr); + for (auto& field : origin->fields()) { + auto copy = std::make_shared(field->name(), field->type(), field->nullable(), nullptr); fields.emplace_back(copy); } return std::make_shared(std::move(fields)); } - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.h b/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.h index d19f8f3ae5a46267b6d3f9de7b78de73c1d43b70..75580cd16354dde245749180548d6acdb6451479 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.h +++ b/cpp/src/core/knowhere/knowhere/adapter/ArrowAdapter.h @@ -15,22 +15,20 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include +#include +#include #include "knowhere/common/Array.h" - -namespace zilliz { namespace knowhere { ArrayPtr -CopyArray(const ArrayPtr &origin); +CopyArray(const ArrayPtr& origin); SchemaPtr -CopySchema(const SchemaPtr &origin); +CopySchema(const SchemaPtr& origin); -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.cpp b/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.cpp index 056f4e2cad2aa523b6c9e07a6ab43e9856ea9b79..b4c3910a01f5281e7770e933032162dc30c84255 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.cpp +++ b/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.cpp @@ -15,36 +15,30 @@ // specific language governing permissions and limitations // under the License. +#include "knowhere/adapter/SptagAdapter.h" +#include "knowhere/adapter/Structure.h" #include "knowhere/index/vector_index/helpers/Definitions.h" -#include "SptagAdapter.h" -#include "Structure.h" - -namespace zilliz { namespace knowhere { - std::shared_ptr -ConvertToMetadataSet(const DatasetPtr &dataset) { +ConvertToMetadataSet(const DatasetPtr& dataset) { auto array = dataset->array()[0]; auto elems = array->length(); auto p_data = array->data()->GetValues(1, 0); - auto p_offset = (int64_t *) malloc(sizeof(int64_t) * elems); - for (auto i = 0; i <= elems; ++i) - p_offset[i] = i * 8; - - std::shared_ptr metaset(new SPTAG::MemMetadataSet( - SPTAG::ByteArray((std::uint8_t *) p_data, elems * sizeof(int64_t), false), - SPTAG::ByteArray((std::uint8_t *) p_offset, elems * sizeof(int64_t), true), - elems)); + auto p_offset = (int64_t*)malloc(sizeof(int64_t) * elems); + for (auto i = 0; i <= elems; ++i) p_offset[i] = i * 8; + std::shared_ptr metaset( + new SPTAG::MemMetadataSet(SPTAG::ByteArray((std::uint8_t*)p_data, elems * sizeof(int64_t), false), + SPTAG::ByteArray((std::uint8_t*)p_offset, elems * sizeof(int64_t), true), elems)); return metaset; } std::shared_ptr -ConvertToVectorSet(const DatasetPtr &dataset) { +ConvertToVectorSet(const DatasetPtr& dataset) { auto tensor = dataset->tensor()[0]; auto p_data = tensor->raw_mutable_data(); @@ -54,18 +48,16 @@ ConvertToVectorSet(const DatasetPtr &dataset) { SPTAG::ByteArray byte_array(p_data, num_bytes, false); - auto vectorset = std::make_shared(byte_array, - SPTAG::VectorValueType::Float, - dimension, - rows); + auto vectorset = + std::make_shared(byte_array, SPTAG::VectorValueType::Float, dimension, rows); return vectorset; } std::vector -ConvertToQueryResult(const DatasetPtr &dataset, const Config &config) { +ConvertToQueryResult(const DatasetPtr& dataset, const Config& config) { auto tensor = dataset->tensor()[0]; - auto p_data = (float *) tensor->raw_mutable_data(); + auto p_data = (float*)tensor->raw_mutable_data(); auto dimension = tensor->shape()[1]; auto rows = tensor->shape()[0]; @@ -82,8 +74,8 @@ ConvertToDataset(std::vector query_results) { auto k = query_results[0].GetResultNum(); auto elems = query_results.size() * k; - auto p_id = (int64_t *) malloc(sizeof(int64_t) * elems); - auto p_dist = (float *) malloc(sizeof(float) * elems); + auto p_id = (int64_t*)malloc(sizeof(int64_t) * elems); + auto p_dist = (float*)malloc(sizeof(float) * elems); // TODO: throw if malloc failed. #pragma omp parallel for @@ -91,14 +83,14 @@ ConvertToDataset(std::vector query_results) { auto results = query_results[i].GetResults(); auto num_result = query_results[i].GetResultNum(); for (auto j = 0; j < num_result; ++j) { -// p_id[i * k + j] = results[j].VID; - p_id[i * k + j] = *(int64_t *) query_results[i].GetMetadata(j).Data(); + // p_id[i * k + j] = results[j].VID; + p_id[i * k + j] = *(int64_t*)query_results[i].GetMetadata(j).Data(); p_dist[i * k + j] = results[j].Dist; } } - auto id_buf = MakeMutableBufferSmart((uint8_t *) p_id, sizeof(int64_t) * elems); - auto dist_buf = MakeMutableBufferSmart((uint8_t *) p_dist, sizeof(float) * elems); + auto id_buf = MakeMutableBufferSmart((uint8_t*)p_id, sizeof(int64_t) * elems); + auto dist_buf = MakeMutableBufferSmart((uint8_t*)p_dist, sizeof(float) * elems); // TODO: magic std::vector id_bufs{nullptr, id_buf}; @@ -109,11 +101,11 @@ ConvertToDataset(std::vector query_results) { auto id_array_data = arrow::ArrayData::Make(int64_type, elems, id_bufs); auto dist_array_data = arrow::ArrayData::Make(float_type, elems, dist_bufs); -// auto id_array_data = std::make_shared(int64_type, sizeof(int64_t) * elems, id_bufs); -// auto dist_array_data = std::make_shared(float_type, sizeof(float) * elems, dist_bufs); + // auto id_array_data = std::make_shared(int64_type, sizeof(int64_t) * elems, id_bufs); + // auto dist_array_data = std::make_shared(float_type, sizeof(float) * elems, dist_bufs); -// auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems); -// auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems); + // auto ids = ConstructInt64Array((uint8_t*)p_id, sizeof(int64_t) * elems); + // auto dists = ConstructFloatArray((uint8_t*)p_dist, sizeof(float) * elems); auto ids = std::make_shared>(id_array_data); auto dists = std::make_shared>(dist_array_data); @@ -127,5 +119,4 @@ ConvertToDataset(std::vector query_results) { return std::make_shared(array, schema); } -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.h b/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.h index f47ffdc3b5814d2e0f97008a08983b00bb62cb88..9f924975623f8c775ad7e54f71ad2c3d18b2d082 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.h +++ b/cpp/src/core/knowhere/knowhere/adapter/SptagAdapter.h @@ -15,29 +15,26 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include - #include +#include +#include #include "knowhere/common/Dataset.h" -namespace zilliz { namespace knowhere { std::shared_ptr -ConvertToVectorSet(const DatasetPtr &dataset); +ConvertToVectorSet(const DatasetPtr& dataset); std::shared_ptr -ConvertToMetadataSet(const DatasetPtr &dataset); +ConvertToMetadataSet(const DatasetPtr& dataset); std::vector -ConvertToQueryResult(const DatasetPtr &dataset, const Config &config); +ConvertToQueryResult(const DatasetPtr& dataset, const Config& config); DatasetPtr ConvertToDataset(std::vector query_results); -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/Structure.cpp b/cpp/src/core/knowhere/knowhere/adapter/Structure.cpp index 18833b5d36d7e43ac32d9c5998e97d31443c288b..44b068c7929461a4da2399a813392e8b9d00a57c 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/Structure.cpp +++ b/cpp/src/core/knowhere/knowhere/adapter/Structure.cpp @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. +#include "knowhere/adapter/Structure.h" -#include "Structure.h" +#include +#include - -namespace zilliz { namespace knowhere { ArrayPtr -ConstructInt64ArraySmart(uint8_t *data, int64_t size) { +ConstructInt64ArraySmart(uint8_t* data, int64_t size) { // TODO: magic std::vector id_buf{nullptr, MakeMutableBufferSmart(data, size)}; auto type = std::make_shared(); @@ -32,7 +32,7 @@ ConstructInt64ArraySmart(uint8_t *data, int64_t size) { } ArrayPtr -ConstructFloatArraySmart(uint8_t *data, int64_t size) { +ConstructFloatArraySmart(uint8_t* data, int64_t size) { // TODO: magic std::vector id_buf{nullptr, MakeMutableBufferSmart(data, size)}; auto type = std::make_shared(); @@ -41,14 +41,14 @@ ConstructFloatArraySmart(uint8_t *data, int64_t size) { } TensorPtr -ConstructFloatTensorSmart(uint8_t *data, int64_t size, std::vector shape) { +ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector shape) { auto buffer = MakeMutableBufferSmart(data, size); auto float_type = std::make_shared(); return std::make_shared(float_type, buffer, shape); } ArrayPtr -ConstructInt64Array(uint8_t *data, int64_t size) { +ConstructInt64Array(uint8_t* data, int64_t size) { // TODO: magic std::vector id_buf{nullptr, MakeMutableBuffer(data, size)}; auto type = std::make_shared(); @@ -57,7 +57,7 @@ ConstructInt64Array(uint8_t *data, int64_t size) { } ArrayPtr -ConstructFloatArray(uint8_t *data, int64_t size) { +ConstructFloatArray(uint8_t* data, int64_t size) { // TODO: magic std::vector id_buf{nullptr, MakeMutableBuffer(data, size)}; auto type = std::make_shared(); @@ -66,23 +66,22 @@ ConstructFloatArray(uint8_t *data, int64_t size) { } TensorPtr -ConstructFloatTensor(uint8_t *data, int64_t size, std::vector shape) { +ConstructFloatTensor(uint8_t* data, int64_t size, std::vector shape) { auto buffer = MakeMutableBuffer(data, size); auto float_type = std::make_shared(); return std::make_shared(float_type, buffer, shape); } FieldPtr -ConstructInt64Field(const std::string &name) { +ConstructInt64Field(const std::string& name) { auto type = std::make_shared(); return std::make_shared(name, type); } - FieldPtr -ConstructFloatField(const std::string &name) { +ConstructFloatField(const std::string& name) { auto type = std::make_shared(); return std::make_shared(name, type); } -} -} + +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/Structure.h b/cpp/src/core/knowhere/knowhere/adapter/Structure.h index 7539dce2dae90c4d1bd5c2e1ad2f84be3b97e6d5..6bde9ddfe6fcb675bbf447224d17f21965001a81 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/Structure.h +++ b/cpp/src/core/knowhere/knowhere/adapter/Structure.h @@ -15,40 +15,38 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include -#include "knowhere/common/Dataset.h" +#include +#include +#include "knowhere/common/Dataset.h" -namespace zilliz { namespace knowhere { extern ArrayPtr -ConstructInt64ArraySmart(uint8_t *data, int64_t size); +ConstructInt64ArraySmart(uint8_t* data, int64_t size); extern ArrayPtr -ConstructFloatArraySmart(uint8_t *data, int64_t size); +ConstructFloatArraySmart(uint8_t* data, int64_t size); extern TensorPtr -ConstructFloatTensorSmart(uint8_t *data, int64_t size, std::vector shape); +ConstructFloatTensorSmart(uint8_t* data, int64_t size, std::vector shape); extern ArrayPtr -ConstructInt64Array(uint8_t *data, int64_t size); +ConstructInt64Array(uint8_t* data, int64_t size); extern ArrayPtr -ConstructFloatArray(uint8_t *data, int64_t size); +ConstructFloatArray(uint8_t* data, int64_t size); extern TensorPtr -ConstructFloatTensor(uint8_t *data, int64_t size, std::vector shape); +ConstructFloatTensor(uint8_t* data, int64_t size, std::vector shape); extern FieldPtr -ConstructInt64Field(const std::string &name); +ConstructInt64Field(const std::string& name); extern FieldPtr -ConstructFloatField(const std::string &name); - +ConstructFloatField(const std::string& name); -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/adapter/VectorAdapter.h b/cpp/src/core/knowhere/knowhere/adapter/VectorAdapter.h index bdace40a745a302ae5773c11a743573eafe1ae5b..2b16227bb302781ef420a9174a9c746e8d5e1684 100644 --- a/cpp/src/core/knowhere/knowhere/adapter/VectorAdapter.h +++ b/cpp/src/core/knowhere/knowhere/adapter/VectorAdapter.h @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. - #pragma once -namespace zilliz { namespace knowhere { -#define GETTENSOR(dataset) \ - auto tensor = dataset->tensor()[0]; \ - auto p_data = tensor->raw_data(); \ - auto dim = tensor->shape()[1]; \ - auto rows = tensor->shape()[0]; \ - +#define GETTENSOR(dataset) \ + auto tensor = dataset->tensor()[0]; \ + auto p_data = tensor->raw_data(); \ + auto dim = tensor->shape()[1]; \ + auto rows = tensor->shape()[0]; -} -} \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Array.h b/cpp/src/core/knowhere/knowhere/common/Array.h index 94a5470029b7fbd5e44e4d5951e39e573fdb1382..71ad78b79beaff1808bac55194cbc6a23378a4b8 100644 --- a/cpp/src/core/knowhere/knowhere/common/Array.h +++ b/cpp/src/core/knowhere/knowhere/common/Array.h @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include +#include #include "Schema.h" - -namespace zilliz { namespace knowhere { using ArrayData = arrow::ArrayData; @@ -35,9 +33,9 @@ using ArrayPtr = std::shared_ptr; using BooleanArray = arrow::BooleanArray; using BooleanArrayPtr = std::shared_ptr; -template +template using NumericArray = arrow::NumericArray; -template +template using NumericArrayPtr = std::shared_ptr>; using BinaryArray = arrow::BinaryArray; @@ -49,6 +47,4 @@ using FixedSizeBinaryArrayPtr = std::shared_ptr; using Decimal128Array = arrow::Decimal128Array; using Decimal128ArrayPtr = std::shared_ptr; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/BinarySet.h b/cpp/src/core/knowhere/knowhere/common/BinarySet.h index 509db68e2d25a7a19488032e41c8b6ba2e871d3f..6e6d53000e4a96ee326d0aa720c7d7232ca126d6 100644 --- a/cpp/src/core/knowhere/knowhere/common/BinarySet.h +++ b/cpp/src/core/knowhere/knowhere/common/BinarySet.h @@ -15,21 +15,18 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include +#include #include +#include #include -#include #include "Id.h" - -namespace zilliz { namespace knowhere { - struct Binary { ID id; std::shared_ptr data; @@ -37,29 +34,28 @@ struct Binary { }; using BinaryPtr = std::shared_ptr; - class BinarySet { public: BinaryPtr - GetByName(const std::string &name) const { + GetByName(const std::string& name) const { return binary_map_.at(name); } void - Append(const std::string &name, BinaryPtr binary) { + Append(const std::string& name, BinaryPtr binary) { binary_map_[name] = std::move(binary); } void - Append(const std::string &name, std::shared_ptr data, int64_t size) { + Append(const std::string& name, std::shared_ptr data, int64_t size) { auto binary = std::make_shared(); binary->data = data; binary->size = size; binary_map_[name] = std::move(binary); } - //void - //Append(const std::string &name, void *data, int64_t size, ID id) { + // void + // Append(const std::string &name, void *data, int64_t size, ID id) { // Binary binary; // binary.data = data; // binary.size = size; @@ -67,7 +63,8 @@ class BinarySet { // binary_map_[name] = binary; //} - void clear() { + void + clear() { binary_map_.clear(); } @@ -75,6 +72,4 @@ class BinarySet { std::map binary_map_; }; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Buffer.h b/cpp/src/core/knowhere/knowhere/common/Buffer.h index 4468e6ec01cb5c1ceff47ccb75ed73ab8a32a843..f9e15d95bdd31a260f80a335f53d82371c748794 100644 --- a/cpp/src/core/knowhere/knowhere/common/Buffer.h +++ b/cpp/src/core/knowhere/knowhere/common/Buffer.h @@ -15,15 +15,12 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include - -namespace zilliz { namespace knowhere { using Buffer = arrow::Buffer; @@ -34,31 +31,31 @@ using MutableBufferPtr = std::shared_ptr; namespace internal { struct BufferDeleter { - void operator()(Buffer *buffer) { - free((void *) buffer->data()); + void + operator()(Buffer* buffer) { + free((void*)buffer->data()); } }; +} // namespace internal -} inline BufferPtr -MakeBufferSmart(uint8_t *data, const int64_t size) { +MakeBufferSmart(uint8_t* data, const int64_t size) { return BufferPtr(new Buffer(data, size), internal::BufferDeleter()); } inline MutableBufferPtr -MakeMutableBufferSmart(uint8_t *data, const int64_t size) { +MakeMutableBufferSmart(uint8_t* data, const int64_t size) { return MutableBufferPtr(new MutableBuffer(data, size), internal::BufferDeleter()); } inline BufferPtr -MakeBuffer(uint8_t *data, const int64_t size) { +MakeBuffer(uint8_t* data, const int64_t size) { return std::make_shared(data, size); } inline MutableBufferPtr -MakeMutableBuffer(uint8_t *data, const int64_t size) { +MakeMutableBuffer(uint8_t* data, const int64_t size) { return std::make_shared(data, size); } -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Config.h b/cpp/src/core/knowhere/knowhere/common/Config.h index 208641d666e8559a0bfe6556113f4b216ed7834e..6191ecb771f61168dca171982fc76f908a54def5 100644 --- a/cpp/src/core/knowhere/knowhere/common/Config.h +++ b/cpp/src/core/knowhere/knowhere/common/Config.h @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include -namespace zilliz { namespace knowhere { enum class METRICTYPE { @@ -42,18 +40,17 @@ struct Cfg { int64_t gpu_id = DEFAULT_GPUID; int64_t d = DEFAULT_DIM; - Cfg(const int64_t &dim, - const int64_t &k, - const int64_t &gpu_id, - METRICTYPE type) - : d(dim), k(k), gpu_id(gpu_id), metric_type(type) {} + Cfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, METRICTYPE type) + : metric_type(type), k(k), gpu_id(gpu_id), d(dim) { + } Cfg() = default; virtual bool - CheckValid(){}; + CheckValid() { + return true; + } }; using Config = std::shared_ptr; -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Dataset.h b/cpp/src/core/knowhere/knowhere/common/Dataset.h index c2d74a59e57c7e1ab4270570df45c4e87cae4343..1331239dd6bd7b0856a0846a0d74b3b7baf20c90 100644 --- a/cpp/src/core/knowhere/knowhere/common/Dataset.h +++ b/cpp/src/core/knowhere/knowhere/common/Dataset.h @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include #include +#include +#include #include "Array.h" #include "Buffer.h" -#include "Tensor.h" -#include "Schema.h" #include "Config.h" +#include "Schema.h" +#include "Tensor.h" #include "knowhere/adapter/ArrowAdapter.h" - -namespace zilliz { namespace knowhere { class Dataset; @@ -40,34 +38,38 @@ class Dataset { public: Dataset() = default; - Dataset(std::vector &&array, SchemaPtr array_schema, - std::vector &&tensor, SchemaPtr tensor_schema) + Dataset(std::vector&& array, SchemaPtr array_schema, std::vector&& tensor, + SchemaPtr tensor_schema) : array_(std::move(array)), array_schema_(std::move(array_schema)), tensor_(std::move(tensor)), - tensor_schema_(std::move(tensor_schema)) {} + tensor_schema_(std::move(tensor_schema)) { + } Dataset(std::vector array, SchemaPtr array_schema) - : array_(std::move(array)), array_schema_(std::move(array_schema)) {} + : array_(std::move(array)), array_schema_(std::move(array_schema)) { + } Dataset(std::vector tensor, SchemaPtr tensor_schema) - : tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) {} + : tensor_(std::move(tensor)), tensor_schema_(std::move(tensor_schema)) { + } - Dataset(const Dataset &) = delete; - Dataset &operator=(const Dataset &) = delete; + Dataset(const Dataset&) = delete; + Dataset& + operator=(const Dataset&) = delete; DatasetPtr Clone() { auto dataset = std::make_shared(); std::vector clone_array; - for (auto &array : array_) { + for (auto& array : array_) { clone_array.emplace_back(CopyArray(array)); } dataset->set_array(clone_array); std::vector clone_tensor; - for (auto &tensor : tensor_) { + for (auto& tensor : tensor_) { auto buffer = tensor->data(); std::shared_ptr copy_buffer; // TODO: checkout copy success; @@ -86,16 +88,20 @@ class Dataset { } public: - const std::vector & - array() const { return array_; } + const std::vector& + array() const { + return array_; + } void set_array(std::vector array) { array_ = std::move(array); } - const std::vector & - tensor() const { return tensor_; } + const std::vector& + tensor() const { + return tensor_; + } void set_tensor(std::vector tensor) { @@ -103,7 +109,9 @@ class Dataset { } SchemaConstPtr - array_schema() const { return array_schema_; } + array_schema() const { + return array_schema_; + } void set_array_schema(SchemaPtr array_schema) { @@ -111,31 +119,31 @@ class Dataset { } SchemaConstPtr - tensor_schema() const { return tensor_schema_; } + tensor_schema() const { + return tensor_schema_; + } void set_tensor_schema(SchemaPtr tensor_schema) { tensor_schema_ = std::move(tensor_schema); } - //const Config & - //meta() const { return meta_; } + // const Config & + // meta() const { return meta_; } - //void - //set_meta(Config meta) { + // void + // set_meta(Config meta) { // meta_ = std::move(meta); //} private: - SchemaPtr array_schema_; - SchemaPtr tensor_schema_; std::vector array_; + SchemaPtr array_schema_; std::vector tensor_; - //Config meta_; + SchemaPtr tensor_schema_; + // Config meta_; }; using DatasetPtr = std::shared_ptr; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Exception.cpp b/cpp/src/core/knowhere/knowhere/common/Exception.cpp index a77e85ee41a9a856ac668aea3e6d56e44bb8188e..4fef9529ac90c3733b1f903c5ec68266a846cdf1 100644 --- a/cpp/src/core/knowhere/knowhere/common/Exception.cpp +++ b/cpp/src/core/knowhere/knowhere/common/Exception.cpp @@ -15,41 +15,35 @@ // specific language governing permissions and limitations // under the License. - #include -#include "Exception.h" #include "Log.h" +#include "knowhere/common/Exception.h" -namespace zilliz { namespace knowhere { +KnowhereException::KnowhereException(const std::string& msg) : msg(msg) { +} -KnowhereException::KnowhereException(const std::string &msg):msg(msg) {} - -KnowhereException::KnowhereException(const std::string &m, const char *funcName, const char *file, int line) { +KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) { #ifdef DEBUG - int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", - funcName, file, line, m.c_str()); + int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, file, line, m.c_str()); msg.resize(size + 1); - snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", - funcName, file, line, m.c_str()); + snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", funcName, file, line, m.c_str()); #else std::string file_path(file); auto const pos = file_path.find_last_of('/'); - auto filename = file_path.substr(pos+1).c_str(); + auto filename = file_path.substr(pos + 1).c_str(); - int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", - funcName, filename, line, m.c_str()); + int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str()); msg.resize(size + 1); - snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", - funcName, filename, line, m.c_str()); + snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", funcName, filename, line, m.c_str()); #endif } -const char *KnowhereException::what() const noexcept { +const char* +KnowhereException::what() const noexcept { return msg.c_str(); } -} -} \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Exception.h b/cpp/src/core/knowhere/knowhere/common/Exception.h index d357f0501e5e68727ec9fe26e7c56334a59a5cf0..9ffc7838cfdf3d7de794ac9b91ad2a8c88af2b4a 100644 --- a/cpp/src/core/knowhere/knowhere/common/Exception.h +++ b/cpp/src/core/knowhere/knowhere/common/Exception.h @@ -15,46 +15,39 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include - -namespace zilliz { namespace knowhere { class KnowhereException : public std::exception { public: - explicit KnowhereException(const std::string &msg); + explicit KnowhereException(const std::string& msg); - KnowhereException(const std::string &msg, const char *funName, - const char *file, int line); + KnowhereException(const std::string& msg, const char* funName, const char* file, int line); - const char *what() const noexcept override; + const char* + what() const noexcept override; std::string msg; }; +#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what()) -#define KNOHWERE_ERROR_MSG(MSG)\ -printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what()) - -#define KNOWHERE_THROW_MSG(MSG)\ -do {\ - throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);\ -} while (false) - -#define KNOHERE_THROW_FORMAT(FMT, ...)\ - do { \ - std::string __s;\ - int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__);\ - __s.resize(__size + 1);\ - snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__);\ - throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__);\ - } while (false) +#define KNOWHERE_THROW_MSG(MSG) \ + do { \ + throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) +#define KNOHERE_THROW_FORMAT(FMT, ...) \ + do { \ + std::string __s; \ + int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \ + __s.resize(__size + 1); \ + snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \ + throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) -} -} \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Id.h b/cpp/src/core/knowhere/knowhere/common/Id.h index 87e1467c4487a8f03cc40cdeb8fb288777996e29..1c8216168452faecfe1eb9038e63eb3d9eb14f4c 100644 --- a/cpp/src/core/knowhere/knowhere/common/Id.h +++ b/cpp/src/core/knowhere/knowhere/common/Id.h @@ -15,30 +15,27 @@ // specific language governing permissions and limitations // under the License. - #pragma once -//#include "zcommon/id/id.h" -//using ID = zilliz::common::ID; - #include #include -namespace zilliz { namespace knowhere { - - class ID { public: constexpr static int64_t kIDSize = 20; public: - const int32_t * - data() const { return content_; } + const int32_t* + data() const { + return content_; + } - int32_t * - mutable_data() { return content_; } + int32_t* + mutable_data() { + return content_; + } bool IsValid() const; @@ -47,14 +44,13 @@ class ID { ToString() const; bool - operator==(const ID &that) const; + operator==(const ID& that) const; bool - operator<(const ID &that) const; + operator<(const ID& that) const; protected: int32_t content_[5] = {}; }; -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Log.h b/cpp/src/core/knowhere/knowhere/common/Log.h index 1e390b3c1e6c2e1868f72614d81dbfbc65198ee1..222d03d73e65ce2d30c0a2589c223b4d998bd269 100644 --- a/cpp/src/core/knowhere/knowhere/common/Log.h +++ b/cpp/src/core/knowhere/knowhere/common/Log.h @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "utils/easylogging++.h" -namespace zilliz { namespace knowhere { #define KNOWHERE_DOMAIN_NAME "[KNOWHERE] " @@ -33,5 +31,4 @@ namespace knowhere { #define KNOWHERE_LOG_ERROR LOG(ERROR) << KNOWHERE_DOMAIN_NAME #define KNOWHERE_LOG_FATAL LOG(FATAL) << KNOWHERE_DOMAIN_NAME -} // namespace knowhere -} // namespace zilliz \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Schema.h b/cpp/src/core/knowhere/knowhere/common/Schema.h index b43b7eb8753400f236964121a37d047b616436ef..c90bac757223d7646dd60b5b7a894e852ca136ab 100644 --- a/cpp/src/core/knowhere/knowhere/common/Schema.h +++ b/cpp/src/core/knowhere/knowhere/common/Schema.h @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include - -namespace zilliz { namespace knowhere { - using DataType = arrow::DataType; using Field = arrow::Field; using FieldPtr = std::shared_ptr; @@ -34,7 +30,4 @@ using Schema = arrow::Schema; using SchemaPtr = std::shared_ptr; using SchemaConstPtr = std::shared_ptr; - - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Tensor.h b/cpp/src/core/knowhere/knowhere/common/Tensor.h index 42e86dc4d1a8c41bc6a1b6b085d286b927981f05..ff957319e56ac605129d4b1077c4f25c9f1f5722 100644 --- a/cpp/src/core/knowhere/knowhere/common/Tensor.h +++ b/cpp/src/core/knowhere/knowhere/common/Tensor.h @@ -15,21 +15,15 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include - -namespace zilliz { namespace knowhere { - using Tensor = arrow::Tensor; using TensorPtr = std::shared_ptr; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Timer.cpp b/cpp/src/core/knowhere/knowhere/common/Timer.cpp index 7f2cb715149b78cf9fc0e5791791ac1abf8723a6..fefe1e1705bb32a7b739af87ad3e1e5360ebdf31 100644 --- a/cpp/src/core/knowhere/knowhere/common/Timer.cpp +++ b/cpp/src/core/knowhere/knowhere/common/Timer.cpp @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. +#include // TODO(linxj): using Log instead -#include // TODO(linxj): using Log instead +#include "knowhere/common/Timer.h" -#include "Timer.h" - -namespace zilliz { namespace knowhere { -TimeRecorder::TimeRecorder(const std::string &header, - int64_t log_level) : - header_(header), - log_level_(log_level) { +TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) { start_ = last_ = stdclock::now(); } @@ -42,9 +37,10 @@ TimeRecorder::GetTimeSpanStr(double span) { } void -TimeRecorder::PrintTimeRecord(const std::string &msg, double span) { +TimeRecorder::PrintTimeRecord(const std::string& msg, double span) { std::string str_log; - if (!header_.empty()) str_log += header_ + ": "; + if (!header_.empty()) + str_log += header_ + ": "; str_log += msg; str_log += " ("; str_log += TimeRecorder::GetTimeSpanStr(span); @@ -55,35 +51,35 @@ TimeRecorder::PrintTimeRecord(const std::string &msg, double span) { std::cout << str_log << std::endl; break; } - //case 1: { - // SERVER_LOG_DEBUG << str_log; - // break; - //} - //case 2: { - // SERVER_LOG_INFO << str_log; - // break; - //} - //case 3: { - // SERVER_LOG_WARNING << str_log; - // break; - //} - //case 4: { - // SERVER_LOG_ERROR << str_log; - // break; - //} - //case 5: { - // SERVER_LOG_FATAL << str_log; - // break; - //} - //default: { - // SERVER_LOG_INFO << str_log; - // break; - //} + // case 1: { + // SERVER_LOG_DEBUG << str_log; + // break; + //} + // case 2: { + // SERVER_LOG_INFO << str_log; + // break; + //} + // case 3: { + // SERVER_LOG_WARNING << str_log; + // break; + //} + // case 4: { + // SERVER_LOG_ERROR << str_log; + // break; + //} + // case 5: { + // SERVER_LOG_FATAL << str_log; + // break; + //} + // default: { + // SERVER_LOG_INFO << str_log; + // break; + //} } } double -TimeRecorder::RecordSection(const std::string &msg) { +TimeRecorder::RecordSection(const std::string& msg) { stdclock::time_point curr = stdclock::now(); double span = (std::chrono::duration(curr - last_)).count(); last_ = curr; @@ -93,7 +89,7 @@ TimeRecorder::RecordSection(const std::string &msg) { } double -TimeRecorder::ElapseFromBegin(const std::string &msg) { +TimeRecorder::ElapseFromBegin(const std::string& msg) { stdclock::time_point curr = stdclock::now(); double span = (std::chrono::duration(curr - start_)).count(); @@ -101,5 +97,4 @@ TimeRecorder::ElapseFromBegin(const std::string &msg) { return span; } -} -} \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/common/Timer.h b/cpp/src/core/knowhere/knowhere/common/Timer.h index 8ecd60df994e1f64ac321c813949bc15e10a014f..7e01e88aa9a9c9f640979afc0ec994350b3d96de 100644 --- a/cpp/src/core/knowhere/knowhere/common/Timer.h +++ b/cpp/src/core/knowhere/knowhere/common/Timer.h @@ -15,32 +15,33 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include #include +#include -namespace zilliz { namespace knowhere { class TimeRecorder { using stdclock = std::chrono::high_resolution_clock; public: - TimeRecorder(const std::string &header, - int64_t log_level = 0); + explicit TimeRecorder(const std::string& header, int64_t log_level = 0); - ~TimeRecorder();//trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 + ~TimeRecorder(); // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 - double RecordSection(const std::string &msg); + double + RecordSection(const std::string& msg); - double ElapseFromBegin(const std::string &msg); + double + ElapseFromBegin(const std::string& msg); - static std::string GetTimeSpanStr(double span); + static std::string + GetTimeSpanStr(double span); private: - void PrintTimeRecord(const std::string &msg, double span); + void + PrintTimeRecord(const std::string& msg, double span); private: std::string header_; @@ -49,5 +50,4 @@ class TimeRecorder { int64_t log_level_; }; -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/Index.h b/cpp/src/core/knowhere/knowhere/index/Index.h index 4ccbef394cb97d580c2084c1e83672723565b2d2..ffa685689b12c65e45cb3a8fc9b5506392164172 100644 --- a/cpp/src/core/knowhere/knowhere/index/Index.h +++ b/cpp/src/core/knowhere/knowhere/index/Index.h @@ -15,54 +15,53 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include +#include "IndexModel.h" +#include "IndexType.h" #include "knowhere/common/BinarySet.h" #include "knowhere/common/Dataset.h" -#include "IndexType.h" -#include "IndexModel.h" #include "knowhere/index/preprocessor/Preprocessor.h" - -namespace zilliz { namespace knowhere { - class Index { public: virtual BinarySet Serialize() = 0; virtual void - Load(const BinarySet &index_binary) = 0; + Load(const BinarySet& index_binary) = 0; // @throw virtual DatasetPtr - Search(const DatasetPtr &dataset, const Config &config) = 0; + Search(const DatasetPtr& dataset, const Config& config) = 0; public: IndexType - idx_type() const { return idx_type_; } + idx_type() const { + return idx_type_; + } void - set_idx_type(IndexType idx_type) { idx_type_ = idx_type; } + set_idx_type(IndexType idx_type) { + idx_type_ = idx_type; + } virtual void - set_preprocessor(PreprocessorPtr preprocessor) {} + set_preprocessor(PreprocessorPtr preprocessor) { + } virtual void - set_index_model(IndexModelPtr model) {} + set_index_model(IndexModelPtr model) { + } private: IndexType idx_type_; }; - using IndexPtr = std::shared_ptr; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/IndexModel.h b/cpp/src/core/knowhere/knowhere/index/IndexModel.h index 7557363f65a677056ca87bfced32de2ff5258dc8..1229e21f2cb48f6e2b17d51dc62ecdc6a4bb693d 100644 --- a/cpp/src/core/knowhere/knowhere/index/IndexModel.h +++ b/cpp/src/core/knowhere/knowhere/index/IndexModel.h @@ -15,28 +15,22 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include "knowhere/common/BinarySet.h" -namespace zilliz { namespace knowhere { - class IndexModel { public: virtual BinarySet Serialize() = 0; virtual void - Load(const BinarySet &binary) = 0; + Load(const BinarySet& binary) = 0; }; using IndexModelPtr = std::shared_ptr; - - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/IndexType.h b/cpp/src/core/knowhere/knowhere/index/IndexType.h index 3ece6287f5584a0a466e767b9ce028a91d6a0be8..e39b57bf44732c6b6b33bfb89399503177b75c01 100644 --- a/cpp/src/core/knowhere/knowhere/index/IndexType.h +++ b/cpp/src/core/knowhere/knowhere/index/IndexType.h @@ -15,14 +15,10 @@ // specific language governing permissions and limitations // under the License. - #pragma once - -namespace zilliz { namespace knowhere { - enum class IndexType { kUnknown = 0, kVecIdxBegin = 100, @@ -30,6 +26,4 @@ enum class IndexType { kVecIdxEnd, }; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.cpp b/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.cpp index 41f6ad37159ef27dc3b94dfab35cba86b7c0832c..f21c787abe0ed7dad881460623fc4b6d9d975513 100644 --- a/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.cpp +++ b/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.cpp @@ -1,14 +1,30 @@ +//// Licensed to the Apache Software Foundation (ASF) under one +//// or more contributor license agreements. See the NOTICE file +//// distributed with this work for additional information +//// regarding copyright ownership. The ASF licenses this file +//// to you under the Apache License, Version 2.0 (the +//// "License"); you may not use this file except in compliance +//// with the License. You may obtain a copy of the License at +//// +//// http://www.apache.org/licenses/LICENSE-2.0 +//// +//// Unless required by applicable law or agreed to in writing, +//// software distributed under the License is distributed on an +//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +//// KIND, either express or implied. See the License for the +//// specific language governing permissions and limitations +//// under the License. // //#include "knowhere/index/vector_index/definitions.h" //#include "knowhere/common/config.h" -//#include "knowhere/index/preprocessor/normalize.h" +#include "knowhere/index/preprocessor/Normalize.h" // // -//namespace zilliz { -//namespace knowhere { // -//DatasetPtr -//NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) { +// namespace knowhere { +// +// DatasetPtr +// NormalizePreprocessor::Preprocess(const DatasetPtr &dataset) { // // TODO: wrap dataset->tensor // auto tensor = dataset->tensor()[0]; // auto p_data = (float *)tensor->raw_mutable_data(); @@ -21,8 +37,8 @@ // } //} // -//void -//NormalizePreprocessor::Normalize(float *arr, int64_t dimension) { +// void +// NormalizePreprocessor::Normalize(float *arr, int64_t dimension) { // double vector_length = 0; // for (auto j = 0; j < dimension; j++) { // double val = arr[j]; @@ -38,5 +54,4 @@ //} // //} // namespace knowhere -//} // namespace zilliz - +// diff --git a/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.h b/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.h index 399a5d349034ef0e7308448f646f22df26f8bb79..572c447b2794bd447a4edf7ea42f16ebc5344ad8 100644 --- a/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.h +++ b/cpp/src/core/knowhere/knowhere/index/preprocessor/Normalize.h @@ -1,13 +1,30 @@ +//// Licensed to the Apache Software Foundation (ASF) under one +//// or more contributor license agreements. See the NOTICE file +//// distributed with this work for additional information +//// regarding copyright ownership. The ASF licenses this file +//// to you under the Apache License, Version 2.0 (the +//// "License"); you may not use this file except in compliance +//// with the License. You may obtain a copy of the License at +//// +//// http://www.apache.org/licenses/LICENSE-2.0 +//// +//// Unless required by applicable law or agreed to in writing, +//// software distributed under the License is distributed on an +//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +//// KIND, either express or implied. See the License for the +//// specific language governing permissions and limitations +//// under the License. +// //#pragma once // //#include //#include "preprocessor.h" // // -//namespace zilliz { -//namespace knowhere { // -//class NormalizePreprocessor : public Preprocessor { +// namespace knowhere { +// +// class NormalizePreprocessor : public Preprocessor { // public: // DatasetPtr // Preprocess(const DatasetPtr &input) override; @@ -19,8 +36,8 @@ //}; // // -//using NormalizePreprocessorPtr = std::shared_ptr; +// using NormalizePreprocessorPtr = std::shared_ptr; // // //} // namespace knowhere -//} // namespace zilliz +// diff --git a/cpp/src/core/knowhere/knowhere/index/preprocessor/Preprocessor.h b/cpp/src/core/knowhere/knowhere/index/preprocessor/Preprocessor.h index e1c01d20850dc9c53bdda07a0424155706652e74..d1cb1817d6e102759bfc8e2be5083417f63c12e6 100644 --- a/cpp/src/core/knowhere/knowhere/index/preprocessor/Preprocessor.h +++ b/cpp/src/core/knowhere/knowhere/index/preprocessor/Preprocessor.h @@ -15,27 +15,20 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include "knowhere/common/Dataset.h" - -namespace zilliz { namespace knowhere { - class Preprocessor { public: virtual DatasetPtr - Preprocess(const DatasetPtr &input) = 0; + Preprocess(const DatasetPtr& input) = 0; }; - using PreprocessorPtr = std::shared_ptr; - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp index a3ab7956cb634ad048b98fb42f624b12a923599f..2612a63630893f6a3e85dd7bdee4d8b54d07268b 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. - #include -#include +#include #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" -#include "FaissBaseIndex.h" - -namespace zilliz { namespace knowhere { -FaissBaseIndex::FaissBaseIndex(std::shared_ptr index) : index_(std::move(index)) {} +FaissBaseIndex::FaissBaseIndex(std::shared_ptr index) : index_(std::move(index)) { +} -BinarySet FaissBaseIndex::SerializeImpl() { +BinarySet +FaissBaseIndex::SerializeImpl() { try { - faiss::Index *index = index_.get(); + faiss::Index* index = index_.get(); SealImpl(); @@ -44,37 +44,37 @@ BinarySet FaissBaseIndex::SerializeImpl() { // TODO(linxj): use virtual func Name() instead of raw string. res_set.Append("IVF", data, writer.rp); return res_set; - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } -void FaissBaseIndex::LoadImpl(const BinarySet &index_binary) { +void +FaissBaseIndex::LoadImpl(const BinarySet& index_binary) { auto binary = index_binary.GetByName("IVF"); MemoryIOReader reader; reader.total = binary->size; reader.data_ = binary->data.get(); - faiss::Index *index = faiss::read_index(&reader); + faiss::Index* index = faiss::read_index(&reader); index_.reset(index); } -void FaissBaseIndex::SealImpl() { -// TODO(linxj): enable -//#ifdef ZILLIZ_FAISS - faiss::Index *index = index_.get(); - auto idx = dynamic_cast(index); +void +FaissBaseIndex::SealImpl() { + // TODO(linxj): enable + //#ifdef ZILLIZ_FAISS + faiss::Index* index = index_.get(); + auto idx = dynamic_cast(index); if (idx != nullptr) { idx->to_readonly(); } - //else { + // else { // KNOHWERE_ERROR_MSG("Seal failed"); //} -//#endif + //#endif } -} // knowhere -} // zilliz - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.h b/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.h index 29edbbd61e92964f7e38c355ce75747c0a307099..f3fceebb88e9858e666ead347e263de0fb124e9a 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/FaissBaseIndex.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include @@ -24,8 +23,6 @@ #include "knowhere/common/BinarySet.h" - -namespace zilliz { namespace knowhere { class FaissBaseIndex { @@ -36,7 +33,7 @@ class FaissBaseIndex { SerializeImpl(); virtual void - LoadImpl(const BinarySet &index_binary); + LoadImpl(const BinarySet& index_binary); virtual void SealImpl(); @@ -45,8 +42,4 @@ class FaissBaseIndex { std::shared_ptr index_ = nullptr; }; -} // knowhere -} // zilliz - - - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp index 667f8ec4f15a0ef037f3cb17663468fe6db43028..a5e8f90f34a5a65ef3fc9b7016d7ee329ce0c744 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp @@ -15,30 +15,27 @@ // specific language governing permissions and limitations // under the License. - - +#include #include #include #include -#include -#include #include +#include - +#include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/helpers/Cloner.h" -#include "knowhere/adapter/VectorAdapter.h" -#include "IndexGPUIVF.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" - -namespace zilliz { namespace knowhere { -IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +GPUIVF::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } gpu_id_ = build_cfg->gpu_id; @@ -49,10 +46,9 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { ResScope rs(temp_resource, gpu_id_, true); faiss::gpu::GpuIndexIVFFlatConfig idx_config; idx_config.device = gpu_id_; - faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, - build_cfg->nlist, GetMetricType(build_cfg->metric_type), - idx_config); - device_index.train(rows, (float *) p_data); + faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, build_cfg->nlist, + GetMetricType(build_cfg->metric_type), idx_config); + device_index.train(rows, (float*)p_data); std::shared_ptr host_index = nullptr; host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index)); @@ -63,7 +59,8 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { } } -void GPUIVF::set_index_model(IndexModelPtr model) { +void +GPUIVF::set_index_model(IndexModelPtr model) { std::lock_guard lk(mutex_); auto host_index = std::static_pointer_cast(model); @@ -77,7 +74,8 @@ void GPUIVF::set_index_model(IndexModelPtr model) { } } -BinarySet GPUIVF::SerializeImpl() { +BinarySet +GPUIVF::SerializeImpl() { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -85,8 +83,8 @@ BinarySet GPUIVF::SerializeImpl() { try { MemoryIOWriter writer; { - faiss::Index *index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index); + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index); SealImpl(); @@ -100,19 +98,20 @@ BinarySet GPUIVF::SerializeImpl() { res_set.Append("IVF", data, writer.rp); return res_set; - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } -void GPUIVF::LoadImpl(const BinarySet &index_binary) { +void +GPUIVF::LoadImpl(const BinarySet& index_binary) { auto binary = index_binary.GetByName("IVF"); MemoryIOReader reader; { reader.total = binary->size; reader.data_ = binary->data.get(); - faiss::Index *index = faiss::read_index(&reader); + faiss::Index* index = faiss::read_index(&reader); if (auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) { ResScope rs(temp_res, gpu_id_, false); @@ -127,12 +126,8 @@ void GPUIVF::LoadImpl(const BinarySet &index_binary) { } } -void GPUIVF::search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) { +void +GPUIVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) { std::lock_guard lk(mutex_); // TODO(linxj): gpu index support GenParams @@ -143,19 +138,20 @@ void GPUIVF::search_impl(int64_t n, { // TODO(linxj): allocate gpu mem ResScope rs(res_, gpu_id_); - device_index->search(n, (float *) data, k, distances, labels); + device_index->search(n, (float*)data, k, distances, labels); } } else { KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); } } -VectorIndexPtr GPUIVF::CopyGpuToCpu(const Config &config) { +VectorIndexPtr +GPUIVF::CopyGpuToCpu(const Config& config) { std::lock_guard lk(mutex_); - if ( auto device_idx = std::dynamic_pointer_cast(index_)) { - faiss::Index *device_index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index); + if (auto device_idx = std::dynamic_pointer_cast(index_)) { + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); std::shared_ptr new_index; new_index.reset(host_index); @@ -165,33 +161,36 @@ VectorIndexPtr GPUIVF::CopyGpuToCpu(const Config &config) { } } -VectorIndexPtr GPUIVF::Clone() { +VectorIndexPtr +GPUIVF::Clone() { auto cpu_idx = CopyGpuToCpu(Config()); - return ::zilliz::knowhere::cloner::CopyCpuToGpu(cpu_idx, gpu_id_, Config()); + return knowhere::cloner::CopyCpuToGpu(cpu_idx, gpu_id_, Config()); } -VectorIndexPtr GPUIVF::CopyGpuToGpu(const int64_t &device_id, const Config &config) { +VectorIndexPtr +GPUIVF::CopyGpuToGpu(const int64_t& device_id, const Config& config) { auto host_index = CopyGpuToCpu(config); return std::static_pointer_cast(host_index)->CopyCpuToGpu(device_id, config); } -void GPUIVF::Add(const DatasetPtr &dataset, const Config &config) { +void +GPUIVF::Add(const DatasetPtr& dataset, const Config& config) { if (auto spt = res_.lock()) { ResScope rs(res_, gpu_id_); IVF::Add(dataset, config); - } - else { + } else { KNOWHERE_THROW_MSG("Add IVF can't get gpu resource"); } } -void GPUIndex::SetGpuDevice(const int &gpu_id) { +void +GPUIndex::SetGpuDevice(const int& gpu_id) { gpu_id_ = gpu_id; } -const int64_t &GPUIndex::GetGpuDevice() { +const int64_t& +GPUIndex::GetGpuDevice() { return gpu_id_; } -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.h index 82695a7e4a34ffcc55e0507b5d03b5d389f23cd1..fa9a206c48185899d7ad7af62184a2f58e35cf93 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVF.h @@ -15,81 +15,78 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include +#include #include "IndexIVF.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" - -namespace zilliz { namespace knowhere { class GPUIndex { -public: - explicit GPUIndex(const int &device_id) : gpu_id_(device_id) {} + public: + explicit GPUIndex(const int& device_id) : gpu_id_(device_id) { + } - GPUIndex(const int& device_id, const ResPtr& resource): gpu_id_(device_id), res_(resource) {} + GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) { + } - virtual VectorIndexPtr - CopyGpuToCpu(const Config &config) = 0; + virtual VectorIndexPtr + CopyGpuToCpu(const Config& config) = 0; - virtual VectorIndexPtr - CopyGpuToGpu(const int64_t &device_id, const Config &config) = 0; + virtual VectorIndexPtr + CopyGpuToGpu(const int64_t& device_id, const Config& config) = 0; - void - SetGpuDevice(const int &gpu_id); + void + SetGpuDevice(const int& gpu_id); - const int64_t & - GetGpuDevice(); + const int64_t& + GetGpuDevice(); -protected: - int64_t gpu_id_; - ResWPtr res_; + protected: + int64_t gpu_id_; + ResWPtr res_; }; class GPUIVF : public IVF, public GPUIndex { -public: - explicit GPUIVF(const int &device_id) : IVF(), GPUIndex(device_id) {} + public: + explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) { + } - explicit GPUIVF(std::shared_ptr index, const int64_t &device_id, ResPtr &resource) - : IVF(std::move(index)), GPUIndex(device_id, resource) {}; + explicit GPUIVF(std::shared_ptr index, const int64_t& device_id, ResPtr& resource) + : IVF(std::move(index)), GPUIndex(device_id, resource) { + } - IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + IndexModelPtr + Train(const DatasetPtr& dataset, const Config& config) override; - void - Add(const DatasetPtr &dataset, const Config &config) override; + void + Add(const DatasetPtr& dataset, const Config& config) override; - void - set_index_model(IndexModelPtr model) override; + void + set_index_model(IndexModelPtr model) override; - //DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override; - VectorIndexPtr - CopyGpuToCpu(const Config &config) override; + // DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override; + VectorIndexPtr + CopyGpuToCpu(const Config& config) override; - VectorIndexPtr - CopyGpuToGpu(const int64_t &device_id, const Config &config) override; + VectorIndexPtr + CopyGpuToGpu(const int64_t& device_id, const Config& config) override; - VectorIndexPtr - Clone() final; + VectorIndexPtr + Clone() final; -protected: - void - search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) override; + protected: + void + search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override; - BinarySet - SerializeImpl() override; + BinarySet + SerializeImpl() override; - void - LoadImpl(const BinarySet &index_binary) override; + void + LoadImpl(const BinarySet& index_binary) override; }; -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp index 03193ea60425390f7192604d946d0de4f3f1bcca..213141b3ace72615cdd8a9c9b0d9ed1a5ba809be 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. - -#include -#include #include +#include +#include +#include -#include "IndexGPUIVFPQ.h" -#include "knowhere/common/Exception.h" #include "knowhere/adapter/VectorAdapter.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" - -namespace zilliz { namespace knowhere { -IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +GPUIVFPQ::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } gpu_id_ = build_cfg->gpu_id; @@ -40,9 +40,9 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) { // TODO(linxj): set device here. // TODO(linxj): set gpu resource here. faiss::gpu::StandardGpuResources res; - faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, build_cfg->nlist, build_cfg->m, - build_cfg->nbits, GetMetricType(build_cfg->metric_type)); // IP not support - device_index.train(rows, (float *) p_data); + faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits, + GetMetricType(build_cfg->metric_type)); // IP not support + device_index.train(rows, (float*)p_data); std::shared_ptr host_index = nullptr; host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index)); @@ -50,20 +50,21 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) { return std::make_shared(host_index); } -std::shared_ptr GPUIVFPQ::GenParams(const Config &config) { +std::shared_ptr +GPUIVFPQ::GenParams(const Config& config) { auto params = std::make_shared(); auto search_cfg = std::dynamic_pointer_cast(config); params->nprobe = search_cfg->nprobe; -// params->scan_table_threshold = conf->scan_table_threhold; -// params->polysemous_ht = conf->polysemous_ht; -// params->max_codes = conf->max_codes; + // params->scan_table_threshold = conf->scan_table_threhold; + // params->polysemous_ht = conf->polysemous_ht; + // params->max_codes = conf->max_codes; return params; } -VectorIndexPtr GPUIVFPQ::CopyGpuToCpu(const Config &config) { +VectorIndexPtr +GPUIVFPQ::CopyGpuToCpu(const Config& config) { KNOWHERE_THROW_MSG("not support yet"); } -} // knowhere -} // zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.h index c407ee0cc7a3f1a1d2c48fa816b5c4f55b167cdb..13ea1075ca5740bc6094d0ec2e59bcdd501f3d91 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.h @@ -15,33 +15,30 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include + #include "IndexGPUIVF.h" -namespace zilliz { namespace knowhere { class GPUIVFPQ : public GPUIVF { -public: - explicit GPUIVFPQ(const int &device_id) : GPUIVF(device_id) {} + public: + explicit GPUIVFPQ(const int& device_id) : GPUIVF(device_id) { + } IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; -public: + public: VectorIndexPtr - CopyGpuToCpu(const Config &config) override; + CopyGpuToCpu(const Config& config) override; -protected: + protected: // TODO(linxj): remove GenParams. std::shared_ptr - GenParams(const Config &config) override; + GenParams(const Config& config) override; }; -} // knowhere -} // zilliz - - - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp index af23267404e60cb1c2ac75bc94535bd7b31e72d3..1b4f4e9edba8edf4e2c19b5c48ea1bebed3cf432 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp @@ -15,60 +15,60 @@ // specific language governing permissions and limitations // under the License. - #include +#include +#include #include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" -#include "IndexGPUIVFSQ.h" -#include "IndexIVFSQ.h" - +#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" -namespace zilliz { namespace knowhere { - IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) { - auto build_cfg = std::dynamic_pointer_cast(config); - if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception - } - gpu_id_ = build_cfg->gpu_id; +IndexModelPtr +GPUIVFSQ::Train(const DatasetPtr& dataset, const Config& config) { + auto build_cfg = std::dynamic_pointer_cast(config); + if (build_cfg != nullptr) { + build_cfg->CheckValid(); // throw exception + } + gpu_id_ = build_cfg->gpu_id; - GETTENSOR(dataset) + GETTENSOR(dataset) - std::stringstream index_type; - index_type << "IVF" << build_cfg->nlist << "," << "SQ" << build_cfg->nbits; - auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type)); + std::stringstream index_type; + index_type << "IVF" << build_cfg->nlist << "," + << "SQ" << build_cfg->nbits; + auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type)); - auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); - if (temp_resource != nullptr) { - ResScope rs(temp_resource, gpu_id_, true); - auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index); - device_index->train(rows, (float *) p_data); + auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (temp_resource != nullptr) { + ResScope rs(temp_resource, gpu_id_, true); + auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index); + device_index->train(rows, (float*)p_data); - std::shared_ptr host_index = nullptr; - host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index)); + std::shared_ptr host_index = nullptr; + host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index)); - delete device_index; - delete build_index; + delete device_index; + delete build_index; - return std::make_shared(host_index); - } else { - KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource"); - } + return std::make_shared(host_index); + } else { + KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource"); } +} - VectorIndexPtr GPUIVFSQ::CopyGpuToCpu(const Config &config) { - std::lock_guard lk(mutex_); +VectorIndexPtr +GPUIVFSQ::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); - faiss::Index *device_index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index); - - std::shared_ptr new_index; - new_index.reset(host_index); - return std::make_shared(new_index); - } + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); -} // knowhere -} // zilliz + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.h index 677907964bbf33d0f3c6aa359c1be7c37988c38a..ed8013d77f1f87535a8cabb7ca256c7ee28bc399 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.h @@ -15,29 +15,29 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "IndexGPUIVF.h" +#include +#include +#include "IndexGPUIVF.h" -namespace zilliz { namespace knowhere { class GPUIVFSQ : public GPUIVF { -public: - explicit GPUIVFSQ(const int &device_id) : GPUIVF(device_id) {} + public: + explicit GPUIVFSQ(const int& device_id) : GPUIVF(device_id) { + } - explicit GPUIVFSQ(std::shared_ptr index, const int64_t &device_id, ResPtr &resource) - : GPUIVF(std::move(index), device_id, resource) {}; + explicit GPUIVFSQ(std::shared_ptr index, const int64_t& device_id, ResPtr& resource) + : GPUIVF(std::move(index), device_id, resource) { + } IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; VectorIndexPtr - CopyGpuToCpu(const Config &config) override; + CopyGpuToCpu(const Config& config) override; }; -} // knowhere -} // zilliz - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp index f5e3fd8f40758e5f95fafbdccbe11c0b0a75fbb5..2371591b5c179ba546489aa53c424c99b4fd92d6 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -15,24 +15,22 @@ // specific language governing permissions and limitations // under the License. - -#include #include +#include #include -#include #include +#include +#include - -#include "knowhere/common/Exception.h" #include "knowhere/adapter/VectorAdapter.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" -#include "IndexIDMAP.h" - -namespace zilliz { namespace knowhere { -BinarySet IDMAP::Serialize() { +BinarySet +IDMAP::Serialize() { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } @@ -41,31 +39,33 @@ BinarySet IDMAP::Serialize() { return SerializeImpl(); } -void IDMAP::Load(const BinarySet &index_binary) { +void +IDMAP::Load(const BinarySet& index_binary) { std::lock_guard lk(mutex_); LoadImpl(index_binary); } -DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) { +DatasetPtr +IDMAP::Search(const DatasetPtr& dataset, const Config& config) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } config->CheckValid(); - //auto metric_type = config["metric_type"].as_string() == "L2" ? + // auto metric_type = config["metric_type"].as_string() == "L2" ? // faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT; - //index_->metric_type = metric_type; + // index_->metric_type = metric_type; GETTENSOR(dataset) auto elems = rows * config->k; - auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems); - auto res_dis = (float *) malloc(sizeof(float) * elems); + auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems); + auto res_dis = (float*)malloc(sizeof(float) * elems); - search_impl(rows, (float *) p_data, config->k, res_dis, res_ids, Config()); + search_impl(rows, (float*)p_data, config->k, res_dis, res_ids, Config()); - auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems); - auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems); + auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems); + auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems); std::vector id_bufs{nullptr, id_buf}; std::vector dist_bufs{nullptr, dist_buf}; @@ -83,12 +83,13 @@ DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) { return std::make_shared(array, nullptr); } -void IDMAP::search_impl(int64_t n, const float *data, int64_t k, float *distances, int64_t *labels, const Config &cfg) { - index_->search(n, (float *) data, k, distances, labels); - +void +IDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) { + index_->search(n, (float*)data, k, distances, labels); } -void IDMAP::Add(const DatasetPtr &dataset, const Config &config) { +void +IDMAP::Add(const DatasetPtr& dataset, const Config& config) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } @@ -98,49 +99,56 @@ void IDMAP::Add(const DatasetPtr &dataset, const Config &config) { // TODO: magic here. auto array = dataset->array()[0]; - auto p_ids = array->data()->GetValues(1, 0); + auto p_ids = array->data()->GetValues(1, 0); - index_->add_with_ids(rows, (float *) p_data, p_ids); + index_->add_with_ids(rows, (float*)p_data, p_ids); } -int64_t IDMAP::Count() { +int64_t +IDMAP::Count() { return index_->ntotal; } -int64_t IDMAP::Dimension() { +int64_t +IDMAP::Dimension() { return index_->d; } // TODO(linxj): return const pointer -float *IDMAP::GetRawVectors() { +float* +IDMAP::GetRawVectors() { try { - auto file_index = dynamic_cast(index_.get()); + auto file_index = dynamic_cast(index_.get()); auto flat_index = dynamic_cast(file_index->index); return flat_index->xb.data(); - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } // TODO(linxj): return const pointer -int64_t *IDMAP::GetRawIds() { +int64_t* +IDMAP::GetRawIds() { try { - auto file_index = dynamic_cast(index_.get()); + auto file_index = dynamic_cast(index_.get()); return file_index->id_map.data(); - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } const char* type = "IDMap,Flat"; -void IDMAP::Train(const Config &config) { + +void +IDMAP::Train(const Config& config) { config->CheckValid(); auto index = faiss::index_factory(config->d, type, GetMetricType(config->metric_type)); index_.reset(index); } -VectorIndexPtr IDMAP::Clone() { +VectorIndexPtr +IDMAP::Clone() { std::lock_guard lk(mutex_); auto clone_index = faiss::clone_index(index_.get()); @@ -149,8 +157,9 @@ VectorIndexPtr IDMAP::Clone() { return std::make_shared(new_index); } -VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &config) { - if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){ +VectorIndexPtr +IDMAP::CopyCpuToGpu(const int64_t& device_id, const Config& config) { + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); @@ -162,38 +171,41 @@ VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &confi } } -void IDMAP::Seal() { +void +IDMAP::Seal() { // do nothing } -VectorIndexPtr GPUIDMAP::CopyGpuToCpu(const Config &config) { +VectorIndexPtr +GPUIDMAP::CopyGpuToCpu(const Config& config) { std::lock_guard lk(mutex_); - faiss::Index *device_index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index); + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); std::shared_ptr new_index; new_index.reset(host_index); return std::make_shared(new_index); } -VectorIndexPtr GPUIDMAP::Clone() { +VectorIndexPtr +GPUIDMAP::Clone() { auto cpu_idx = CopyGpuToCpu(Config()); - if (auto idmap = std::dynamic_pointer_cast(cpu_idx)){ + if (auto idmap = std::dynamic_pointer_cast(cpu_idx)) { return idmap->CopyCpuToGpu(gpu_id_, Config()); - } - else { + } else { KNOWHERE_THROW_MSG("IndexType not Support GpuClone"); } } -BinarySet GPUIDMAP::SerializeImpl() { +BinarySet +GPUIDMAP::SerializeImpl() { try { MemoryIOWriter writer; { - faiss::Index *index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(index); + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index); faiss::write_index(host_index, &writer); delete host_index; @@ -205,21 +217,22 @@ BinarySet GPUIDMAP::SerializeImpl() { res_set.Append("IVF", data, writer.rp); return res_set; - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } -void GPUIDMAP::LoadImpl(const BinarySet &index_binary) { +void +GPUIDMAP::LoadImpl(const BinarySet& index_binary) { auto binary = index_binary.GetByName("IVF"); MemoryIOReader reader; { reader.total = binary->size; reader.data_ = binary->data.get(); - faiss::Index *index = faiss::read_index(&reader); + faiss::Index* index = faiss::read_index(&reader); - if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_) ){ + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) { ResScope rs(res, gpu_id_, false); auto device_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index); index_.reset(device_index); @@ -232,28 +245,26 @@ void GPUIDMAP::LoadImpl(const BinarySet &index_binary) { } } -VectorIndexPtr GPUIDMAP::CopyGpuToGpu(const int64_t &device_id, const Config &config) { +VectorIndexPtr +GPUIDMAP::CopyGpuToGpu(const int64_t& device_id, const Config& config) { auto cpu_index = CopyGpuToCpu(config); return std::static_pointer_cast(cpu_index)->CopyCpuToGpu(device_id, config); } -float *GPUIDMAP::GetRawVectors() { +float* +GPUIDMAP::GetRawVectors() { KNOWHERE_THROW_MSG("Not support"); } -int64_t *GPUIDMAP::GetRawIds() { +int64_t* +GPUIDMAP::GetRawIds() { KNOWHERE_THROW_MSG("Not support"); } -void GPUIDMAP::search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) { +void +GPUIDMAP::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) { ResScope rs(res_, gpu_id_); - index_->search(n, (float *) data, k, distances, labels); + index_->search(n, (float*)data, k, distances, labels); } -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.h index 106759faf1c6380243c3ce769422eacb8d0c49a3..ec1cbb9e770735d4c92f91492085a37b1783e5a2 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIDMAP.h @@ -15,41 +15,53 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "IndexIVF.h" #include "IndexGPUIVF.h" +#include "IndexIVF.h" +#include +#include -namespace zilliz { namespace knowhere { class IDMAP : public VectorIndex, public FaissBaseIndex { public: - IDMAP() : FaissBaseIndex(nullptr) {}; - explicit IDMAP(std::shared_ptr index) : FaissBaseIndex(std::move(index)) {}; - BinarySet Serialize() override; - void Load(const BinarySet &index_binary) override; - void Train(const Config &config); - DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override; - int64_t Count() override; - VectorIndexPtr Clone() override; - int64_t Dimension() override; - void Add(const DatasetPtr &dataset, const Config &config) override; - VectorIndexPtr CopyCpuToGpu(const int64_t &device_id, const Config &config); - void Seal() override; + IDMAP() : FaissBaseIndex(nullptr) { + } + + explicit IDMAP(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + } + + BinarySet + Serialize() override; + void + Load(const BinarySet& index_binary) override; + void + Train(const Config& config); + DatasetPtr + Search(const DatasetPtr& dataset, const Config& config) override; + int64_t + Count() override; + VectorIndexPtr + Clone() override; + int64_t + Dimension() override; + void + Add(const DatasetPtr& dataset, const Config& config) override; + VectorIndexPtr + CopyCpuToGpu(const int64_t& device_id, const Config& config); + void + Seal() override; - virtual float *GetRawVectors(); - virtual int64_t *GetRawIds(); + virtual float* + GetRawVectors(); + virtual int64_t* + GetRawIds(); protected: - virtual void search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg); + virtual void + search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg); std::mutex mutex_; }; @@ -57,27 +69,30 @@ using IDMAPPtr = std::shared_ptr; class GPUIDMAP : public IDMAP, public GPUIndex { public: - explicit GPUIDMAP(std::shared_ptr index, const int64_t &device_id, ResPtr& res) - : IDMAP(std::move(index)), GPUIndex(device_id, res) {} + explicit GPUIDMAP(std::shared_ptr index, const int64_t& device_id, ResPtr& res) + : IDMAP(std::move(index)), GPUIndex(device_id, res) { + } - VectorIndexPtr CopyGpuToCpu(const Config &config) override; - float *GetRawVectors() override; - int64_t *GetRawIds() override; - VectorIndexPtr Clone() override; - VectorIndexPtr CopyGpuToGpu(const int64_t &device_id, const Config &config) override; + VectorIndexPtr + CopyGpuToCpu(const Config& config) override; + float* + GetRawVectors() override; + int64_t* + GetRawIds() override; + VectorIndexPtr + Clone() override; + VectorIndexPtr + CopyGpuToGpu(const int64_t& device_id, const Config& config) override; protected: - void search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) override; - BinarySet SerializeImpl() override; - void LoadImpl(const BinarySet &index_binary) override; + void + search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override; + BinarySet + SerializeImpl() override; + void + LoadImpl(const BinarySet& index_binary) override; }; using GPUIDMAPPtr = std::shared_ptr; -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.cpp index b8f5ee2811ab7b41dc7bff15c9c4d5aa4bdf348e..510ab46bd6de3bd0683a0a78cf433835a5c8144b 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.cpp @@ -15,47 +15,46 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include #include #include #include #include -#include -#include -#include -#include #include +#include +#include +#include +#include - -#include "knowhere/common/Exception.h" #include "knowhere/adapter/VectorAdapter.h" -#include "IndexIVF.h" -#include "IndexGPUIVF.h" - +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexIVF.h" -namespace zilliz { namespace knowhere { - -IndexModelPtr IVF::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +IVF::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } GETTENSOR(dataset) - faiss::Index *coarse_quantizer = new faiss::IndexFlatL2(dim); - auto index = std::make_shared(coarse_quantizer, dim, - build_cfg->nlist, + faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim); + auto index = std::make_shared(coarse_quantizer, dim, build_cfg->nlist, GetMetricType(build_cfg->metric_type)); - index->train(rows, (float *) p_data); + index->train(rows, (float*)p_data); // TODO(linxj): override here. train return model or not. return std::make_shared(index); } - -void IVF::Add(const DatasetPtr &dataset, const Config &config) { +void +IVF::Add(const DatasetPtr& dataset, const Config& config) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -64,11 +63,12 @@ void IVF::Add(const DatasetPtr &dataset, const Config &config) { GETTENSOR(dataset) auto array = dataset->array()[0]; - auto p_ids = array->data()->GetValues(1, 0); - index_->add_with_ids(rows, (float *) p_data, p_ids); + auto p_ids = array->data()->GetValues(1, 0); + index_->add_with_ids(rows, (float*)p_data, p_ids); } -void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) { +void +IVF::AddWithoutIds(const DatasetPtr& dataset, const Config& config) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -76,10 +76,11 @@ void IVF::AddWithoutIds(const DatasetPtr &dataset, const Config &config) { std::lock_guard lk(mutex_); GETTENSOR(dataset) - index_->add(rows, (float *) p_data); + index_->add(rows, (float*)p_data); } -BinarySet IVF::Serialize() { +BinarySet +IVF::Serialize() { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } @@ -89,31 +90,33 @@ BinarySet IVF::Serialize() { return SerializeImpl(); } -void IVF::Load(const BinarySet &index_binary) { +void +IVF::Load(const BinarySet& index_binary) { std::lock_guard lk(mutex_); LoadImpl(index_binary); } -DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) { +DatasetPtr +IVF::Search(const DatasetPtr& dataset, const Config& config) { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } auto search_cfg = std::dynamic_pointer_cast(config); if (search_cfg != nullptr) { - search_cfg->CheckValid(); // throw exception + search_cfg->CheckValid(); // throw exception } GETTENSOR(dataset) auto elems = rows * search_cfg->k; - auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems); - auto res_dis = (float *) malloc(sizeof(float) * elems); + auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems); + auto res_dis = (float*)malloc(sizeof(float) * elems); - search_impl(rows, (float*) p_data, search_cfg->k, res_dis, res_ids, config); + search_impl(rows, (float*)p_data, search_cfg->k, res_dis, res_ids, config); - auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems); - auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems); + auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems); + auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems); std::vector id_bufs{nullptr, id_buf}; std::vector dist_bufs{nullptr, dist_buf}; @@ -131,7 +134,8 @@ DatasetPtr IVF::Search(const DatasetPtr &dataset, const Config &config) { return std::make_shared(array, nullptr); } -void IVF::set_index_model(IndexModelPtr model) { +void +IVF::set_index_model(IndexModelPtr model) { std::lock_guard lk(mutex_); auto rel_model = std::static_pointer_cast(model); @@ -140,25 +144,29 @@ void IVF::set_index_model(IndexModelPtr model) { index_.reset(faiss::clone_index(rel_model->index_.get())); } -std::shared_ptr IVF::GenParams(const Config &config) { +std::shared_ptr +IVF::GenParams(const Config& config) { auto params = std::make_shared(); auto search_cfg = std::dynamic_pointer_cast(config); params->nprobe = search_cfg->nprobe; - //params->max_codes = config.get_with_default("max_codes", size_t(0)); + // params->max_codes = config.get_with_default("max_codes", size_t(0)); return params; } -int64_t IVF::Count() { +int64_t +IVF::Count() { return index_->ntotal; } -int64_t IVF::Dimension() { +int64_t +IVF::Dimension() { return index_->d; } -void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config) { +void +IVF::GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config) { GETTENSOR(dataset) auto ntotal = Count(); @@ -174,7 +182,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co for (int i = 0; i < total_search_count; ++i) { auto b_size = i == total_search_count - 1 && tail_batch_size != 0 ? tail_batch_size : batch_size; - auto &res = res_vec[i]; + auto& res = res_vec[i]; res.resize(k * b_size); auto xq = p_data + batch_size * dim * i; @@ -182,7 +190,7 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co int tmp = 0; for (int j = 0; j < b_size; ++j) { - auto &node = graph[batch_size * i + j]; + auto& node = graph[batch_size * i + j]; node.resize(k); for (int m = 0; m < k && tmp < k * b_size; ++m, ++tmp) { // TODO(linxj): avoid memcopy here. @@ -192,18 +200,15 @@ void IVF::GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, co } } -void IVF::search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) { +void +IVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) { auto params = GenParams(cfg); - faiss::ivflib::search_with_parameters(index_.get(), n, (float *) data, k, distances, labels, params.get()); + faiss::ivflib::search_with_parameters(index_.get(), n, (float*)data, k, distances, labels, params.get()); } -VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config) { - if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){ +VectorIndexPtr +IVF::CopyCpuToGpu(const int64_t& device_id, const Config& config) { + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); @@ -215,7 +220,8 @@ VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config) } } -VectorIndexPtr IVF::Clone() { +VectorIndexPtr +IVF::Clone() { std::lock_guard lk(mutex_); auto clone_index = faiss::clone_index(index_.get()); @@ -224,21 +230,24 @@ VectorIndexPtr IVF::Clone() { return Clone_impl(new_index); } -VectorIndexPtr IVF::Clone_impl(const std::shared_ptr &index) { +VectorIndexPtr +IVF::Clone_impl(const std::shared_ptr& index) { return std::make_shared(index); } -void IVF::Seal() { +void +IVF::Seal() { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } SealImpl(); } +IVFIndexModel::IVFIndexModel(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { +} -IVFIndexModel::IVFIndexModel(std::shared_ptr index) : FaissBaseIndex(std::move(index)) {} - -BinarySet IVFIndexModel::Serialize() { +BinarySet +IVFIndexModel::Serialize() { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("indexmodel not initialize or trained"); } @@ -246,18 +255,15 @@ BinarySet IVFIndexModel::Serialize() { return SerializeImpl(); } -void IVFIndexModel::Load(const BinarySet &binary_set) { +void +IVFIndexModel::Load(const BinarySet& binary_set) { std::lock_guard lk(mutex_); LoadImpl(binary_set); } -void IVFIndexModel::SealImpl() { +void +IVFIndexModel::SealImpl() { // do nothing } - - - - -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.h index 72f98a5afab5979d576355d33c8be194d86ef822..ef9982fa300dfe24c4afaf4986c7fde0735834cb 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVF.h @@ -15,54 +15,55 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include #include +#include +#include -#include "VectorIndex.h" #include "FaissBaseIndex.h" +#include "VectorIndex.h" #include "faiss/IndexIVF.h" - -namespace zilliz { namespace knowhere { using Graph = std::vector>; class IVF : public VectorIndex, protected FaissBaseIndex { public: - IVF() : FaissBaseIndex(nullptr) {}; + IVF() : FaissBaseIndex(nullptr) { + } - explicit IVF(std::shared_ptr index) : FaissBaseIndex(std::move(index)) {} + explicit IVF(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + } VectorIndexPtr - Clone() override;; + Clone() override; IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; void set_index_model(IndexModelPtr model) override; void - Add(const DatasetPtr &dataset, const Config &config) override; + Add(const DatasetPtr& dataset, const Config& config) override; void - AddWithoutIds(const DatasetPtr &dataset, const Config &config); + AddWithoutIds(const DatasetPtr& dataset, const Config& config); DatasetPtr - Search(const DatasetPtr &dataset, const Config &config) override; + Search(const DatasetPtr& dataset, const Config& config) override; void - GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config); + GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const Config& config); BinarySet Serialize() override; void - Load(const BinarySet &index_binary) override; + Load(const BinarySet& index_binary) override; int64_t Count() override; @@ -74,23 +75,17 @@ class IVF : public VectorIndex, protected FaissBaseIndex { Seal() override; virtual VectorIndexPtr - CopyCpuToGpu(const int64_t &device_id, const Config &config); - + CopyCpuToGpu(const int64_t& device_id, const Config& config); protected: virtual std::shared_ptr - GenParams(const Config &config); + GenParams(const Config& config); virtual VectorIndexPtr - Clone_impl(const std::shared_ptr &index); + Clone_impl(const std::shared_ptr& index); virtual void - search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg); + search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg); protected: std::mutex mutex_; @@ -106,13 +101,14 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex { public: explicit IVFIndexModel(std::shared_ptr index); - IVFIndexModel() : FaissBaseIndex(nullptr) {}; + IVFIndexModel() : FaissBaseIndex(nullptr) { + } BinarySet Serialize() override; void - Load(const BinarySet &binary) override; + Load(const BinarySet& binary) override; protected: void @@ -121,7 +117,7 @@ class IVFIndexModel : public IndexModel, public FaissBaseIndex { protected: std::mutex mutex_; }; + using IVFIndexModelPtr = std::shared_ptr; -} -} \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp index b5418b77cfec6efe570767bef811fda2c6747e49..03acbf31d774730b9514bd6b46d93866f384b5da 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp @@ -15,47 +15,49 @@ // specific language governing permissions and limitations // under the License. - #include #include +#include +#include -#include "IndexIVFPQ.h" -#include "knowhere/common/Exception.h" #include "knowhere/adapter/VectorAdapter.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" -namespace zilliz { namespace knowhere { -IndexModelPtr IVFPQ::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +IVFPQ::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } GETTENSOR(dataset) - faiss::Index *coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type)); - auto index = std::make_shared(coarse_quantizer, dim, - build_cfg->nlist, build_cfg->m, build_cfg->nbits); - index->train(rows, (float *) p_data); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(build_cfg->metric_type)); + auto index = + std::make_shared(coarse_quantizer, dim, build_cfg->nlist, build_cfg->m, build_cfg->nbits); + index->train(rows, (float*)p_data); return std::make_shared(index); } -std::shared_ptr IVFPQ::GenParams(const Config &config) { +std::shared_ptr +IVFPQ::GenParams(const Config& config) { auto params = std::make_shared(); auto search_cfg = std::dynamic_pointer_cast(config); params->nprobe = search_cfg->nprobe; -// params->scan_table_threshold = conf->scan_table_threhold; -// params->polysemous_ht = conf->polysemous_ht; -// params->max_codes = conf->max_codes; + // params->scan_table_threshold = conf->scan_table_threhold; + // params->polysemous_ht = conf->polysemous_ht; + // params->max_codes = conf->max_codes; return params; } -VectorIndexPtr IVFPQ::Clone_impl(const std::shared_ptr &index) { +VectorIndexPtr +IVFPQ::Clone_impl(const std::shared_ptr& index) { return std::make_shared(index); } -} // knowhere -} // zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.h index 427f76fe3e634edab425facb332de2c47705835d..69aaa5090bd46a5cc115f8f7bd177da4135bbcdd 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFPQ.h @@ -15,33 +15,31 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include +#include + #include "IndexIVF.h" -namespace zilliz { namespace knowhere { class IVFPQ : public IVF { -public: - explicit IVFPQ(std::shared_ptr index) : IVF(std::move(index)) {} + public: + explicit IVFPQ(std::shared_ptr index) : IVF(std::move(index)) { + } IVFPQ() = default; IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; -protected: + protected: std::shared_ptr - GenParams(const Config &config) override; + GenParams(const Config& config) override; VectorIndexPtr - Clone_impl(const std::shared_ptr &index) override; + Clone_impl(const std::shared_ptr& index) override; }; -} // knowhere -} // zilliz - - - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp index ec29627ebc8b256eaa1706453caeb9f69b0d510f..11e508549ac3f0363605fb3ccba93ce6851e8aa8 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp @@ -15,44 +15,45 @@ // specific language governing permissions and limitations // under the License. - #include +#include +#include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" -#include "knowhere/adapter/VectorAdapter.h" -#include "IndexIVFSQ.h" -#include "IndexGPUIVFSQ.h" - -namespace zilliz { namespace knowhere { -IndexModelPtr IVFSQ::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +IVFSQ::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } GETTENSOR(dataset) std::stringstream index_type; - index_type << "IVF" << build_cfg->nlist << "," << "SQ" << build_cfg->nbits; - auto build_index = faiss::index_factory(dim, index_type.str().c_str(), - GetMetricType(build_cfg->metric_type)); - build_index->train(rows, (float *) p_data); + index_type << "IVF" << build_cfg->nlist << "," + << "SQ" << build_cfg->nbits; + auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type)); + build_index->train(rows, (float*)p_data); std::shared_ptr ret_index; ret_index.reset(build_index); return std::make_shared(ret_index); } -VectorIndexPtr IVFSQ::Clone_impl(const std::shared_ptr &index) { +VectorIndexPtr +IVFSQ::Clone_impl(const std::shared_ptr& index) { return std::make_shared(index); } -VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &config) { - if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){ +VectorIndexPtr +IVFSQ::CopyCpuToGpu(const int64_t& device_id, const Config& config) { + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); faiss::gpu::GpuClonerOptions option; option.allInGpu = true; @@ -67,5 +68,4 @@ VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &confi } } -} // knowhere -} // zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.h index 5ee771f951c8d7b8ae3a7803f26f25e14852b45e..cac95faebf10d6e4f187f46ae8e61b5f2c95343e 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQ.h @@ -15,31 +15,31 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include +#include + #include "IndexIVF.h" -namespace zilliz { namespace knowhere { class IVFSQ : public IVF { -public: - explicit IVFSQ(std::shared_ptr index) : IVF(std::move(index)) {} + public: + explicit IVFSQ(std::shared_ptr index) : IVF(std::move(index)) { + } IVFSQ() = default; IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; VectorIndexPtr - CopyCpuToGpu(const int64_t &device_id, const Config &config) override; + CopyCpuToGpu(const int64_t& device_id, const Config& config) override; -protected: + protected: VectorIndexPtr - Clone_impl(const std::shared_ptr &index) override; + Clone_impl(const std::shared_ptr& index) override; }; -} // knowhere -} // zilliz - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp index c36a5d07fa61b2d1be6db34dcb12a0e63adb2ee9..60b1770fc3bba556985151d559a97e1c06ff6fdb 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp @@ -16,36 +16,35 @@ // specific language governing permissions and limitations // under the License. +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" +#include "faiss/AutoTune.h" +#include "faiss/gpu/GpuAutoTune.h" #include "faiss/gpu/GpuIndexIVF.h" #include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" -#include "IndexIVFSQHybrid.h" -#include "faiss/AutoTune.h" -#include "faiss/gpu/GpuAutoTune.h" - -namespace zilliz { namespace knowhere { IndexModelPtr -IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) { +IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } gpu_id_ = build_cfg->gpu_id; GETTENSOR(dataset) std::stringstream index_type; - index_type << "IVF" << build_cfg->nlist << "," << "SQ8Hybrid"; + index_type << "IVF" << build_cfg->nlist << "," + << "SQ8Hybrid"; auto build_index = faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(build_cfg->metric_type)); auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); if (temp_resource != nullptr) { ResScope rs(temp_resource, gpu_id_, true); auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_id_, build_index); - device_index->train(rows, (float *) p_data); + device_index->train(rows, (float*)p_data); std::shared_ptr host_index = nullptr; host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index)); @@ -60,12 +59,12 @@ IVFSQHybrid::Train(const DatasetPtr &dataset, const Config &config) { } VectorIndexPtr -IVFSQHybrid::CopyGpuToCpu(const Config &config) { +IVFSQHybrid::CopyGpuToCpu(const Config& config) { std::lock_guard lk(mutex_); if (auto device_idx = std::dynamic_pointer_cast(index_)) { - faiss::Index *device_index = index_.get(); - faiss::Index *host_index = faiss::gpu::index_gpu_to_cpu(device_index); + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); std::shared_ptr new_index; new_index.reset(host_index); @@ -77,7 +76,7 @@ IVFSQHybrid::CopyGpuToCpu(const Config &config) { } VectorIndexPtr -IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) { +IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) { if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { ResScope rs(res, device_id, false); faiss::gpu::GpuClonerOptions option; @@ -86,7 +85,7 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) { faiss::IndexComposition index_composition; index_composition.index = index_.get(); index_composition.quantizer = nullptr; - index_composition.mode = 0; // copy all + index_composition.mode = 0; // copy all auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, &index_composition, &option); @@ -99,17 +98,13 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t &device_id, const Config &config) { } void -IVFSQHybrid::LoadImpl(const BinarySet &index_binary) { - FaissBaseIndex::LoadImpl(index_binary); // load on cpu +IVFSQHybrid::LoadImpl(const BinarySet& index_binary) { + FaissBaseIndex::LoadImpl(index_binary); // load on cpu } void -IVFSQHybrid::search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) { +IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, + const Config& cfg) { if (gpu_mode) { GPUIVF::search_impl(n, data, k, distances, labels, cfg); } else { @@ -119,10 +114,10 @@ IVFSQHybrid::search_impl(int64_t n, } QuantizerPtr -IVFSQHybrid::LoadQuantizer(const Config &conf) { +IVFSQHybrid::LoadQuantizer(const Config& conf) { auto quantizer_conf = std::dynamic_pointer_cast(conf); if (quantizer_conf != nullptr) { - if(quantizer_conf->mode != 1) { + if (quantizer_conf->mode != 1) { KNOWHERE_THROW_MSG("mode only support 1 in this func"); } } @@ -136,7 +131,7 @@ IVFSQHybrid::LoadQuantizer(const Config &conf) { auto index_composition = new faiss::IndexComposition; index_composition->index = index_.get(); index_composition->quantizer = nullptr; - index_composition->mode = quantizer_conf->mode; // only 1 + index_composition->mode = quantizer_conf->mode; // only 1 auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option); delete gpu_index; @@ -157,10 +152,9 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) { KNOWHERE_THROW_MSG("Quantizer type error"); } - faiss::IndexIVF *ivf_index = - dynamic_cast(index_.get()); + faiss::IndexIVF* ivf_index = dynamic_cast(index_.get()); - faiss::gpu::GpuIndexFlat *is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); if (is_gpu_flat_index == nullptr) { delete ivf_index->quantizer; ivf_index->quantizer = ivf_quantizer->quantizer; @@ -169,8 +163,8 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) { void IVFSQHybrid::UnsetQuantizer() { - auto *ivf_index = dynamic_cast(index_.get()); - if(ivf_index == nullptr) { + auto* ivf_index = dynamic_cast(index_.get()); + if (ivf_index == nullptr) { KNOWHERE_THROW_MSG("Index type error"); } @@ -178,10 +172,10 @@ IVFSQHybrid::UnsetQuantizer() { } void -IVFSQHybrid::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) { +IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { auto quantizer_conf = std::dynamic_pointer_cast(conf); if (quantizer_conf != nullptr) { - if(quantizer_conf->mode != 2) { + if (quantizer_conf->mode != 2) { KNOWHERE_THROW_MSG("mode only support 2 in this func"); } } @@ -195,20 +189,20 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) { option.allInGpu = true; auto ivf_quantizer = std::dynamic_pointer_cast(q); - if (ivf_quantizer == nullptr) KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer"); + if (ivf_quantizer == nullptr) + KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer"); auto index_composition = new faiss::IndexComposition; index_composition->index = index_.get(); index_composition->quantizer = ivf_quantizer->quantizer; - index_composition->mode = quantizer_conf->mode; // only 2 + index_composition->mode = quantizer_conf->mode; // only 2 auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option); index_.reset(gpu_index); - gpu_mode = true; // all in gpu + gpu_mode = true; // all in gpu } else { KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); } } -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h index 687eeda73cefafae487ed1002703903e6a0a2167..1ec67760ff89ee951ecf44707817b1c1a272f107 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.h @@ -15,27 +15,24 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include #include +#include #include "IndexGPUIVFSQ.h" #include "Quantizer.h" - -namespace zilliz { namespace knowhere { struct FaissIVFQuantizer : public Quantizer { - faiss::gpu::GpuIndexFlat *quantizer = nullptr; + faiss::gpu::GpuIndexFlat* quantizer = nullptr; }; using FaissIVFQuantizerPtr = std::shared_ptr; class IVFSQHybrid : public GPUIVFSQ { public: - explicit IVFSQHybrid(const int &device_id) : GPUIVFSQ(device_id) { + explicit IVFSQHybrid(const int& device_id) : GPUIVFSQ(device_id) { gpu_mode = false; } @@ -44,14 +41,14 @@ class IVFSQHybrid : public GPUIVFSQ { gpu_mode = false; } - explicit IVFSQHybrid(std::shared_ptr index, const int64_t &device_id, ResPtr &resource) + explicit IVFSQHybrid(std::shared_ptr index, const int64_t& device_id, ResPtr& resource) : GPUIVFSQ(index, device_id, resource) { gpu_mode = true; } public: QuantizerPtr - LoadQuantizer(const Config &conf); + LoadQuantizer(const Config& conf); void SetQuantizer(const QuantizerPtr& q); @@ -60,31 +57,26 @@ class IVFSQHybrid : public GPUIVFSQ { UnsetQuantizer(); void - LoadData(const knowhere::QuantizerPtr &q, const Config& conf); + LoadData(const knowhere::QuantizerPtr& q, const Config& conf); IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; VectorIndexPtr - CopyGpuToCpu(const Config &config) override; + CopyGpuToCpu(const Config& config) override; VectorIndexPtr - CopyCpuToGpu(const int64_t &device_id, const Config &config) override; + CopyCpuToGpu(const int64_t& device_id, const Config& config) override; protected: void - search_impl(int64_t n, - const float *data, - int64_t k, - float *distances, - int64_t *labels, - const Config &cfg) override; + search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) override; - void LoadImpl(const BinarySet &index_binary) override; + void + LoadImpl(const BinarySet& index_binary) override; protected: bool gpu_mode = false; }; -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.cpp index cceb88516ac5e84857e94dff3131187deaa447a4..c23a6ef61d81298b2a53f66889addf05561cc2da 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.cpp @@ -15,41 +15,38 @@ // specific language governing permissions and limitations // under the License. - -#include -#include -#include #include - +#include +#include +#include +#include #undef mkdir -#include "IndexKDT.h" +#include "knowhere/index/vector_index/IndexKDT.h" #include "knowhere/index/vector_index/helpers/Definitions.h" //#include "knowhere/index/preprocessor/normalize.h" -#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h" #include "knowhere/adapter/SptagAdapter.h" #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h" - -namespace zilliz { namespace knowhere { BinarySet CPUKDTRNG::Serialize() { - std::vector index_blobs; + std::vector index_blobs; std::vector index_len; index_ptr_->SaveIndexToMemory(index_blobs, index_len); BinarySet binary_set; auto sample = std::make_shared(); - sample.reset(static_cast(index_blobs[0])); + sample.reset(static_cast(index_blobs[0])); auto tree = std::make_shared(); - tree.reset(static_cast(index_blobs[1])); + tree.reset(static_cast(index_blobs[1])); auto graph = std::make_shared(); - graph.reset(static_cast(index_blobs[2])); + graph.reset(static_cast(index_blobs[2])); auto metadata = std::make_shared(); - metadata.reset(static_cast(index_blobs[3])); + metadata.reset(static_cast(index_blobs[3])); binary_set.Append("samples", sample, index_len[0]); binary_set.Append("tree", tree, index_len[1]); @@ -59,8 +56,8 @@ CPUKDTRNG::Serialize() { } void -CPUKDTRNG::Load(const BinarySet &binary_set) { - std::vector index_blobs; +CPUKDTRNG::Load(const BinarySet& binary_set) { + std::vector index_blobs; auto samples = binary_set.GetByName("samples"); index_blobs.push_back(samples->data.get()); @@ -77,17 +74,17 @@ CPUKDTRNG::Load(const BinarySet &binary_set) { index_ptr_->LoadIndexFromMemory(index_blobs); } -//PreprocessorPtr -//CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) { +// PreprocessorPtr +// CPUKDTRNG::BuildPreprocessor(const DatasetPtr &dataset, const Config &config) { // return std::make_shared(); //} IndexModelPtr -CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) { +CPUKDTRNG::Train(const DatasetPtr& origin, const Config& train_config) { SetParameters(train_config); DatasetPtr dataset = origin->Clone(); - //if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine + // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine // && preprocessor_) { // preprocessor_->Preprocess(dataset); //} @@ -101,11 +98,11 @@ CPUKDTRNG::Train(const DatasetPtr &origin, const Config &train_config) { } void -CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) { +CPUKDTRNG::Add(const DatasetPtr& origin, const Config& add_config) { SetParameters(add_config); DatasetPtr dataset = origin->Clone(); - //if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine + // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine // && preprocessor_) { // preprocessor_->Preprocess(dataset); //} @@ -116,18 +113,18 @@ CPUKDTRNG::Add(const DatasetPtr &origin, const Config &add_config) { } void -CPUKDTRNG::SetParameters(const Config &config) { - for (auto ¶ : KDTParameterMgr::GetInstance().GetKDTParameters()) { -// auto value = config.get_with_default(para.first, para.second); +CPUKDTRNG::SetParameters(const Config& config) { + for (auto& para : KDTParameterMgr::GetInstance().GetKDTParameters()) { + // auto value = config.get_with_default(para.first, para.second); index_ptr_->SetParameter(para.first, para.second); } } DatasetPtr -CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) { +CPUKDTRNG::Search(const DatasetPtr& dataset, const Config& config) { SetParameters(config); auto tensor = dataset->tensor()[0]; - auto p = (float *) tensor->raw_mutable_data(); + auto p = (float*)tensor->raw_mutable_data(); for (auto i = 0; i < 10; ++i) { for (auto j = 0; j < 10; ++j) { std::cout << p[i * 10 + j] << " "; @@ -138,7 +135,7 @@ CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) { #pragma omp parallel for for (auto i = 0; i < query_results.size(); ++i) { - auto target = (float *) query_results[i].GetTarget(); + auto target = (float*)query_results[i].GetTarget(); std::cout << target[0] << ", " << target[1] << ", " << target[2] << std::endl; index_ptr_->SearchIndex(query_results[i]); } @@ -146,27 +143,33 @@ CPUKDTRNG::Search(const DatasetPtr &dataset, const Config &config) { return ConvertToDataset(query_results); } -int64_t CPUKDTRNG::Count() { +int64_t +CPUKDTRNG::Count() { index_ptr_->GetNumSamples(); } -int64_t CPUKDTRNG::Dimension() { + +int64_t +CPUKDTRNG::Dimension() { index_ptr_->GetFeatureDim(); } -VectorIndexPtr CPUKDTRNG::Clone() { +VectorIndexPtr +CPUKDTRNG::Clone() { KNOWHERE_THROW_MSG("not support"); } -void CPUKDTRNG::Seal() { +void +CPUKDTRNG::Seal() { // do nothing } // TODO(linxj): BinarySet -CPUKDTRNGIndexModel::Serialize() {} +CPUKDTRNGIndexModel::Serialize() { +} void -CPUKDTRNGIndexModel::Load(const BinarySet &binary) {} +CPUKDTRNGIndexModel::Load(const BinarySet& binary) { +} -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.h index 7fd41e9baddb67875710ee6125109fbbef7c6523..f6d436995b0d1051bf8c75c2c7a51cae4eb1d890 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexKDT.h @@ -15,53 +15,53 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include #include #include #include "VectorIndex.h" #include "knowhere/index/IndexModel.h" -#include - -namespace zilliz { namespace knowhere { - class CPUKDTRNG : public VectorIndex { public: CPUKDTRNG() { - index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, - SPTAG::VectorValueType::Float); + index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, SPTAG::VectorValueType::Float); index_ptr_->SetParameter("DistCalcMethod", "L2"); } public: BinarySet Serialize() override; - VectorIndexPtr Clone() override; + VectorIndexPtr + Clone() override; void - Load(const BinarySet &index_array) override; + Load(const BinarySet& index_array) override; public: - //PreprocessorPtr - //BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override; - int64_t Count() override; - int64_t Dimension() override; + // PreprocessorPtr + // BuildPreprocessor(const DatasetPtr &dataset, const Config &config) override; + int64_t + Count() override; + int64_t + Dimension() override; IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) override; + Train(const DatasetPtr& dataset, const Config& config) override; void - Add(const DatasetPtr &dataset, const Config &config) override; + Add(const DatasetPtr& dataset, const Config& config) override; DatasetPtr - Search(const DatasetPtr &dataset, const Config &config) override; - void Seal() override; + Search(const DatasetPtr& dataset, const Config& config) override; + void + Seal() override; + private: void - SetParameters(const Config &config); + SetParameters(const Config& config); private: PreprocessorPtr preprocessor_; @@ -76,7 +76,7 @@ class CPUKDTRNGIndexModel : public IndexModel { Serialize() override; void - Load(const BinarySet &binary) override; + Load(const BinarySet& binary) override; private: std::shared_ptr index_; @@ -84,5 +84,4 @@ class CPUKDTRNGIndexModel : public IndexModel { using CPUKDTRNGIndexModelPtr = std::shared_ptr; -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.cpp index 38a7fed23bac6bd34800b87495d25214e393101f..f5519b824089b94885cbfb05ffdabcb6c63a1b36 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.cpp @@ -15,28 +15,26 @@ // specific language governing permissions and limitations // under the License. - -#include "IndexNSG.h" -#include "knowhere/index/vector_index/nsg/NSG.h" -#include "knowhere/index/vector_index/nsg/NSGIO.h" -#include "IndexIDMAP.h" -#include "IndexIVF.h" -#include "IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexNSG.h" #include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/nsg/NSG.h" +#include "knowhere/index/vector_index/nsg/NSGIO.h" - -namespace zilliz { namespace knowhere { -BinarySet NSG::Serialize() { +BinarySet +NSG::Serialize() { if (!index_ || !index_->is_trained) { KNOWHERE_THROW_MSG("index not initialize or trained"); } try { - algo::NsgIndex *index = index_.get(); + algo::NsgIndex* index = index_.get(); MemoryIOWriter writer; algo::write_index(index, writer); @@ -46,12 +44,13 @@ BinarySet NSG::Serialize() { BinarySet res_set; res_set.Append("NSG", data, writer.total); return res_set; - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } -void NSG::Load(const BinarySet &index_binary) { +void +NSG::Load(const BinarySet& index_binary) { try { auto binary = index_binary.GetByName("NSG"); @@ -61,15 +60,16 @@ void NSG::Load(const BinarySet &index_binary) { auto index = algo::read_index(reader); index_.reset(index); - } catch (std::exception &e) { + } catch (std::exception& e) { KNOWHERE_THROW_MSG(e.what()); } } -DatasetPtr NSG::Search(const DatasetPtr &dataset, const Config &config) { +DatasetPtr +NSG::Search(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } if (!index_ || !index_->is_trained) { @@ -79,16 +79,15 @@ DatasetPtr NSG::Search(const DatasetPtr &dataset, const Config &config) { GETTENSOR(dataset) auto elems = rows * build_cfg->k; - auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems); - auto res_dis = (float *) malloc(sizeof(float) * elems); + auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems); + auto res_dis = (float*)malloc(sizeof(float) * elems); algo::SearchParams s_params; s_params.search_length = build_cfg->search_length; - index_->Search((float *) p_data, rows, dim, - build_cfg->k, res_dis, res_ids, s_params); + index_->Search((float*)p_data, rows, dim, build_cfg->k, res_dis, res_ids, s_params); - auto id_buf = MakeMutableBufferSmart((uint8_t *) res_ids, sizeof(int64_t) * elems); - auto dist_buf = MakeMutableBufferSmart((uint8_t *) res_dis, sizeof(float) * elems); + auto id_buf = MakeMutableBufferSmart((uint8_t*)res_ids, sizeof(int64_t) * elems); + auto dist_buf = MakeMutableBufferSmart((uint8_t*)res_dis, sizeof(float) * elems); std::vector id_bufs{nullptr, id_buf}; std::vector dist_bufs{nullptr, dist_buf}; @@ -106,10 +105,11 @@ DatasetPtr NSG::Search(const DatasetPtr &dataset, const Config &config) { return std::make_shared(array, nullptr); } -IndexModelPtr NSG::Train(const DatasetPtr &dataset, const Config &config) { +IndexModelPtr +NSG::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); if (build_cfg != nullptr) { - build_cfg->CheckValid(); // throw exception + build_cfg->CheckValid(); // throw exception } if (build_cfg->metric_type != METRICTYPE::L2) { @@ -132,34 +132,37 @@ IndexModelPtr NSG::Train(const DatasetPtr &dataset, const Config &config) { GETTENSOR(dataset) auto array = dataset->array()[0]; - auto p_ids = array->data()->GetValues(1, 0); + auto p_ids = array->data()->GetValues(1, 0); index_ = std::make_shared(dim, rows); index_->SetKnnGraph(knng); - index_->Build_with_ids(rows, (float *) p_data, (long *) p_ids, b_params); - return nullptr; // TODO(linxj): support serialize + index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params); + return nullptr; // TODO(linxj): support serialize } -void NSG::Add(const DatasetPtr &dataset, const Config &config) { +void +NSG::Add(const DatasetPtr& dataset, const Config& config) { // do nothing } -int64_t NSG::Count() { +int64_t +NSG::Count() { return index_->ntotal; } -int64_t NSG::Dimension() { +int64_t +NSG::Dimension() { return index_->dimension; } -VectorIndexPtr NSG::Clone() { +VectorIndexPtr +NSG::Clone() { KNOWHERE_THROW_MSG("not support"); } -void NSG::Seal() { +void +NSG::Seal() { // do nothing } -} -} - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.h b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.h index 873bba2287d6773844886bfcc623b1c92bccb47b..04a146d58a3aa7c0f704a2c586f119c9765f4df9 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/IndexNSG.h @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "VectorIndex.h" +#include +#include +#include "VectorIndex.h" -namespace zilliz { namespace knowhere { namespace algo { @@ -30,18 +30,30 @@ class NsgIndex; class NSG : public VectorIndex { public: - explicit NSG(const int64_t& gpu_num):gpu_(gpu_num){} + explicit NSG(const int64_t& gpu_num) : gpu_(gpu_num) { + } + NSG() = default; - IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override; - DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override; - void Add(const DatasetPtr &dataset, const Config &config) override; - BinarySet Serialize() override; - void Load(const BinarySet &index_binary) override; - int64_t Count() override; - int64_t Dimension() override; - VectorIndexPtr Clone() override; - void Seal() override; + IndexModelPtr + Train(const DatasetPtr& dataset, const Config& config) override; + DatasetPtr + Search(const DatasetPtr& dataset, const Config& config) override; + void + Add(const DatasetPtr& dataset, const Config& config) override; + BinarySet + Serialize() override; + void + Load(const BinarySet& index_binary) override; + int64_t + Count() override; + int64_t + Dimension() override; + VectorIndexPtr + Clone() override; + void + Seal() override; + private: std::shared_ptr index_; int64_t gpu_; @@ -49,5 +61,4 @@ class NSG : public VectorIndex { using NSGIndexPtr = std::shared_ptr(); -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/Quantizer.h b/cpp/src/core/knowhere/knowhere/index/vector_index/Quantizer.h index b90ef570a0831dca8784cbcb0e94525726b7a13c..ea74e97c8265d5068d1f1d658573c85b3ae615f4 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/Quantizer.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/Quantizer.h @@ -18,9 +18,8 @@ #pragma once #include +#include "knowhere/common/Config.h" - -namespace zilliz { namespace knowhere { struct Quantizer { @@ -29,9 +28,8 @@ struct Quantizer { using QuantizerPtr = std::shared_ptr; struct QuantizerCfg : Cfg { - uint64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data + uint64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data }; using QuantizerConfig = std::shared_ptr; -} -} +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/VectorIndex.h b/cpp/src/core/knowhere/knowhere/index/vector_index/VectorIndex.h index 908989d15de405e52b7d5f50b0ba8e73b82b56bf..810c4d2ea4874b9f9d86e7da2e39c692b6727387 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/VectorIndex.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/VectorIndex.h @@ -15,36 +15,35 @@ // specific language governing permissions and limitations // under the License. - #pragma once - #include #include "knowhere/common/Config.h" -#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "knowhere/common/Dataset.h" #include "knowhere/index/Index.h" #include "knowhere/index/preprocessor/Preprocessor.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" - -namespace zilliz { namespace knowhere { class VectorIndex; using VectorIndexPtr = std::shared_ptr; - class VectorIndex : public Index { public: virtual PreprocessorPtr - BuildPreprocessor(const DatasetPtr &dataset, const Config &config) { return nullptr; } + BuildPreprocessor(const DatasetPtr& dataset, const Config& config) { + return nullptr; + } virtual IndexModelPtr - Train(const DatasetPtr &dataset, const Config &config) { return nullptr; } + Train(const DatasetPtr& dataset, const Config& config) { + return nullptr; + } virtual void - Add(const DatasetPtr &dataset, const Config &config) = 0; + Add(const DatasetPtr& dataset, const Config& config) = 0; virtual void Seal() = 0; @@ -59,7 +58,4 @@ class VectorIndex : public Index { Dimension() = 0; }; - - -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp index 13d4523705e6e2c82622484d039f9b3b0783d300..5ff2bfc2e323ebb669b47d57305342c519af5691 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp @@ -15,22 +15,20 @@ // specific language governing permissions and limitations // under the License. - +#include "knowhere/index/vector_index/helpers/Cloner.h" #include "knowhere/common/Exception.h" -#include "knowhere/index/vector_index/IndexIVF.h" -#include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" #include "knowhere/index/vector_index/IndexIVFSQHybrid.h" -#include "Cloner.h" - -namespace zilliz { namespace knowhere { namespace cloner { -VectorIndexPtr CopyGpuToCpu(const VectorIndexPtr &index, const Config &config) { +VectorIndexPtr +CopyGpuToCpu(const VectorIndexPtr& index, const Config& config) { if (auto device_index = std::dynamic_pointer_cast(index)) { return device_index->CopyGpuToCpu(config); } else { @@ -38,7 +36,8 @@ VectorIndexPtr CopyGpuToCpu(const VectorIndexPtr &index, const Config &config) { } } -VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_id, const Config &config) { +VectorIndexPtr +CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config) { if (auto device_index = std::dynamic_pointer_cast(index)) { return device_index->CopyCpuToGpu(device_id, config); } @@ -60,6 +59,5 @@ VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_i } } -} // cloner -} -} +} // namespace cloner +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.h index a72e4df141ffc2b35a220175b7a66a04f0a9b355..3134238ed86db8c4810c83522d96ed2a65f29244 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Cloner.h @@ -15,23 +15,19 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "knowhere/index/vector_index/VectorIndex.h" - -namespace zilliz { namespace knowhere { namespace cloner { // TODO(linxj): rename CopyToGpu extern VectorIndexPtr -CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_id, const Config &config); +CopyCpuToGpu(const VectorIndexPtr& index, const int64_t& device_id, const Config& config); extern VectorIndexPtr -CopyGpuToCpu(const VectorIndexPtr &index, const Config &config); +CopyGpuToCpu(const VectorIndexPtr& index, const Config& config); -} // cloner -} // knowhere -} // zilliz \ No newline at end of file +} // namespace cloner +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Definitions.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Definitions.h index fef7b01221f792800b4eba2dc73456cadf70cf6a..f39a28d20fd8a6bb972e5f2e89e60706d51e28bd 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Definitions.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/Definitions.h @@ -15,11 +15,8 @@ // specific language governing permissions and limitations // under the License. - #pragma once - -namespace zilliz { namespace knowhere { namespace definition { @@ -27,6 +24,5 @@ namespace definition { #define META_DIM ("dimension") #define META_K ("k") -} // definition -} // knowhere -} // zilliz +} // namespace definition +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp index c50594b4b07b3c3088b2b9c68418b571b734bff0..d74c6bc562e1b79555fef20f5d7e2f5ad24dc5d0 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp @@ -15,25 +15,23 @@ // specific language governing permissions and limitations // under the License. +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" -#include "FaissGpuResourceMgr.h" +#include - -namespace zilliz { namespace knowhere { -FaissGpuResourceMgr &FaissGpuResourceMgr::GetInstance() { +FaissGpuResourceMgr& +FaissGpuResourceMgr::GetInstance() { static FaissGpuResourceMgr instance; return instance; } -void FaissGpuResourceMgr::AllocateTempMem(ResPtr &resource, - const int64_t &device_id, - const int64_t &size) { +void +FaissGpuResourceMgr::AllocateTempMem(ResPtr& resource, const int64_t& device_id, const int64_t& size) { if (size) { resource->faiss_res->setTempMemory(size); - } - else { + } else { auto search = devices_params_.find(device_id); if (search != devices_params_.end()) { resource->faiss_res->setTempMemory(search->second.temp_mem_size); @@ -42,10 +40,8 @@ void FaissGpuResourceMgr::AllocateTempMem(ResPtr &resource, } } -void FaissGpuResourceMgr::InitDevice(int64_t device_id, - int64_t pin_mem_size, - int64_t temp_mem_size, - int64_t res_num) { +void +FaissGpuResourceMgr::InitDevice(int64_t device_id, int64_t pin_mem_size, int64_t temp_mem_size, int64_t res_num) { DeviceParams params; params.pinned_mem_size = pin_mem_size; params.temp_mem_size = temp_mem_size; @@ -54,23 +50,25 @@ void FaissGpuResourceMgr::InitDevice(int64_t device_id, devices_params_.emplace(device_id, params); } -void FaissGpuResourceMgr::InitResource() { - if(is_init) return ; +void +FaissGpuResourceMgr::InitResource() { + if (is_init) + return; is_init = true; - //std::cout << "InitResource" << std::endl; - for(auto& device : devices_params_) { + // std::cout << "InitResource" << std::endl; + for (auto& device : devices_params_) { auto& device_id = device.first; mutex_cache_.emplace(device_id, std::make_unique()); - //std::cout << "Device Id: " << device_id << std::endl; + // std::cout << "Device Id: " << device_id << std::endl; auto& device_param = device.second; auto& bq = idle_map_[device_id]; for (int64_t i = 0; i < device_param.resource_num; ++i) { - //std::cout << "Resource Id: " << i << std::endl; + // std::cout << "Resource Id: " << i << std::endl; auto raw_resource = std::make_shared(); // TODO(linxj): enable set pinned memory @@ -80,11 +78,11 @@ void FaissGpuResourceMgr::InitResource() { bq.Put(res_wrapper); } } - //std::cout << "End initResource" << std::endl; + // std::cout << "End initResource" << std::endl; } -ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id, - const int64_t &alloc_size) { +ResPtr +FaissGpuResourceMgr::GetRes(const int64_t& device_id, const int64_t& alloc_size) { InitResource(); auto finder = idle_map_.find(device_id); @@ -97,7 +95,8 @@ ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id, return nullptr; } -void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res) { +void +FaissGpuResourceMgr::MoveToIdle(const int64_t& device_id, const ResPtr& res) { auto finder = idle_map_.find(device_id); if (finder != idle_map_.end()) { auto& bq = finder->second; @@ -105,8 +104,9 @@ void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res } } -void FaissGpuResourceMgr::Free() { - for (auto &item : idle_map_) { +void +FaissGpuResourceMgr::Free() { + for (auto& item : idle_map_) { auto& bq = item.second; while (!bq.Empty()) { bq.Take(); @@ -117,12 +117,10 @@ void FaissGpuResourceMgr::Free() { void FaissGpuResourceMgr::Dump() { - for (auto &item : idle_map_) { + for (auto& item : idle_map_) { auto& bq = item.second; - std::cout << "device_id: " << item.first - << ", resource count:" << bq.Size(); + std::cout << "device_id: " << item.first << ", resource count:" << bq.Size(); } } -} // knowhere -} // zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h index c1585e1be1ebc8c002a0c080ef7e201fd8c0d694..73f959baa0215d820e12ea3696934698aaf65aad 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include #include #include -#include +#include #include #include "src/utils/BlockingQueue.h" -namespace zilliz { namespace knowhere { struct Resource { - explicit Resource(std::shared_ptr &r) : faiss_res(r) { + explicit Resource(std::shared_ptr& r) : faiss_res(r) { static int64_t global_id = 0; id = global_id++; } @@ -43,19 +42,19 @@ using ResPtr = std::shared_ptr; using ResWPtr = std::weak_ptr; class FaissGpuResourceMgr { -public: + public: friend class ResScope; - using ResBQ = zilliz::milvus::server::BlockingQueue; + using ResBQ = milvus::server::BlockingQueue; -public: + public: struct DeviceParams { int64_t temp_mem_size = 0; int64_t pinned_mem_size = 0; int64_t resource_num = 2; }; -public: - static FaissGpuResourceMgr & + public: + static FaissGpuResourceMgr& GetInstance(); // Free gpu resource, avoid cudaGetDevice error when deallocate. @@ -64,72 +63,71 @@ public: Free(); void - AllocateTempMem(ResPtr &resource, const int64_t& device_id, const int64_t& size); + AllocateTempMem(ResPtr& resource, const int64_t& device_id, const int64_t& size); void - InitDevice(int64_t device_id, - int64_t pin_mem_size = 0, - int64_t temp_mem_size = 0, - int64_t res_num = 2); + InitDevice(int64_t device_id, int64_t pin_mem_size = 0, int64_t temp_mem_size = 0, int64_t res_num = 2); void InitResource(); // allocate gpu memory invoke by build or copy_to_gpu ResPtr - GetRes(const int64_t &device_id, const int64_t& alloc_size = 0); + GetRes(const int64_t& device_id, const int64_t& alloc_size = 0); void - MoveToIdle(const int64_t &device_id, const ResPtr& res); + MoveToIdle(const int64_t& device_id, const ResPtr& res); void Dump(); -protected: + protected: bool is_init = false; - std::map> mutex_cache_; + std::map> mutex_cache_; std::map devices_params_; std::map idle_map_; }; class ResScope { -public: - ResScope(ResPtr &res, const int64_t& device_id, const bool& isown) - : resource(res), device_id(device_id), move(true), own(isown) { + public: + ResScope(ResPtr& res, const int64_t& device_id, const bool& isown) + : resource(res), device_id(device_id), move(true), own(isown) { Lock(); } - ResScope(ResWPtr &res, const int64_t& device_id, const bool& isown) + ResScope(ResWPtr& res, const int64_t& device_id, const bool& isown) : resource(res), device_id(device_id), move(true), own(isown) { Lock(); } // specif for search // get the ownership of gpuresource and gpu - ResScope(ResWPtr &res, const int64_t &device_id) - :device_id(device_id),move(false),own(true) { + ResScope(ResWPtr& res, const int64_t& device_id) : device_id(device_id), move(false), own(true) { resource = res.lock(); Lock(); } - void Lock() { - if (own) FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->lock(); + void + Lock() { + if (own) + FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->lock(); resource->mutex.lock(); } ~ResScope() { - if (own) FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->unlock(); - if (move) FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource); + if (own) + FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->unlock(); + if (move) + FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource); resource->mutex.unlock(); } -private: - ResPtr resource; // hold resource until deconstruct + private: + ResPtr resource; // hold resource until deconstruct int64_t device_id; bool move = true; bool own = false; }; -} // knowhere -} // zilliz \ No newline at end of file +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp index 48167fe92577fbb224c76143529adffbab6358b0..99532137c08ed25b556fe6d473da736c4a9ac2df 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp @@ -15,51 +15,53 @@ // specific language governing permissions and limitations // under the License. - #include -#include "FaissIO.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" -namespace zilliz { namespace knowhere { // TODO(linxj): Get From Config File static size_t magic_num = 2; -size_t MemoryIOWriter::operator()(const void *ptr, size_t size, size_t nitems) { + +size_t +MemoryIOWriter::operator()(const void* ptr, size_t size, size_t nitems) { auto total_need = size * nitems + rp; - if (!data_) { // data == nullptr + if (!data_) { // data == nullptr total = total_need * magic_num; rp = size * nitems; data_ = new uint8_t[total]; - memcpy((void *) (data_), ptr, rp); + memcpy((void*)(data_), ptr, rp); } if (total_need > total) { total = total_need * magic_num; auto new_data = new uint8_t[total]; - memcpy((void *) new_data, (void *) data_, rp); + memcpy((void*)new_data, (void*)data_, rp); delete data_; data_ = new_data; - memcpy((void *) (data_ + rp), ptr, size * nitems); + memcpy((void*)(data_ + rp), ptr, size * nitems); rp = total_need; } else { - memcpy((void *) (data_ + rp), ptr, size * nitems); + memcpy((void*)(data_ + rp), ptr, size * nitems); rp = total_need; } return nitems; } -size_t MemoryIOReader::operator()(void *ptr, size_t size, size_t nitems) { - if (rp >= total) return 0; +size_t +MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) { + if (rp >= total) + return 0; size_t nremain = (total - rp) / size; - if (nremain < nitems) nitems = nremain; - memcpy(ptr, (void *) (data_ + rp), size * nitems); + if (nremain < nitems) + nitems = nremain; + memcpy(ptr, (void*)(data_ + rp), size * nitems); rp += size * nitems; return nitems; } -} // knowhere -} // zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.h index cc089ec367f09762a166b4d5fb4e72c7dec80b37..7cce5bbbaca21ef5b6113bd814e11ebcdf4746bd 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/FaissIO.h @@ -15,34 +15,28 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include -namespace zilliz { namespace knowhere { struct MemoryIOWriter : public faiss::IOWriter { - uint8_t *data_ = nullptr; + uint8_t* data_ = nullptr; size_t total = 0; size_t rp = 0; size_t - operator()(const void *ptr, size_t size, size_t nitems) override; + operator()(const void* ptr, size_t size, size_t nitems) override; }; struct MemoryIOReader : public faiss::IOReader { - uint8_t *data_; + uint8_t* data_; size_t rp = 0; size_t total = 0; size_t - operator()(void *ptr, size_t size, size_t nitems) override; + operator()(void* ptr, size_t size, size_t nitems) override; }; -} // knowhere -} // zilliz - - - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp index 503f0c2f67f81bf52a619242923454b434f6a8e5..20f3388174a1572b4b35d82ab0ba451fa4c6f774 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. - -#include "IndexParameter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "knowhere/common/Exception.h" #include -namespace zilliz { namespace knowhere { -faiss::MetricType GetMetricType(METRICTYPE &type) { - if (type == METRICTYPE::L2) return faiss::METRIC_L2; - if (type == METRICTYPE::IP) return faiss::METRIC_INNER_PRODUCT; - if (type == METRICTYPE::INVALID) KNOWHERE_THROW_MSG("Metric type is invalid"); -} - +faiss::MetricType +GetMetricType(METRICTYPE& type) { + if (type == METRICTYPE::L2) { + return faiss::METRIC_L2; + } + if (type == METRICTYPE::IP) { + return faiss::METRIC_INNER_PRODUCT; + } + KNOWHERE_THROW_MSG("Metric type is invalid"); } -} + +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h index 521754e2e344929d46dc0d00c4db51537a013bcc..b2854abef8611edfb017873a17386c006fdf591b 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "knowhere/common/Config.h" #include +#include +#include "knowhere/common/Config.h" -namespace zilliz { namespace knowhere { -extern faiss::MetricType GetMetricType(METRICTYPE &type); +extern faiss::MetricType +GetMetricType(METRICTYPE& type); // IVF Config constexpr int64_t DEFAULT_NLIST = INVALID_VALUE; @@ -46,18 +46,17 @@ struct IVFCfg : public Cfg { int64_t nlist = DEFAULT_NLIST; int64_t nprobe = DEFAULT_NPROBE; - IVFCfg(const int64_t &dim, - const int64_t &k, - const int64_t &gpu_id, - const int64_t &nlist, - const int64_t &nprobe, + IVFCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe, METRICTYPE type) - : nlist(nlist), nprobe(nprobe), Cfg(dim, k, gpu_id, type) {} + : Cfg(dim, k, gpu_id, type), nlist(nlist), nprobe(nprobe) { + } IVFCfg() = default; bool - CheckValid() override {}; + CheckValid() override { + return true; + }; }; using IVFConfig = std::shared_ptr; @@ -65,45 +64,40 @@ struct IVFSQCfg : public IVFCfg { // TODO(linxj): cpu only support SQ4 SQ6 SQ8 SQ16, gpu only support SQ4, SQ8, SQ16 int64_t nbits = DEFAULT_NBITS; - IVFSQCfg(const int64_t &dim, - const int64_t &k, - const int64_t &gpu_id, - const int64_t &nlist, - const int64_t &nprobe, - const int64_t &nbits, - METRICTYPE type) - : nbits(nbits), IVFCfg(dim, k, gpu_id, nlist, nprobe, type) {} + IVFSQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe, + const int64_t& nbits, METRICTYPE type) + : IVFCfg(dim, k, gpu_id, nlist, nprobe, type), nbits(nbits) { + } IVFSQCfg() = default; bool - CheckValid() override {}; + CheckValid() override { + return true; + }; }; using IVFSQConfig = std::shared_ptr; struct IVFPQCfg : public IVFCfg { - int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector) - int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index + int64_t m = DEFAULT_NSUBVECTORS; // number of subquantizers(subvector) + int64_t nbits = DEFAULT_NBITS; // number of bit per subvector index // TODO(linxj): not use yet int64_t scan_table_threhold = DEFAULT_SCAN_TABLE_THREHOLD; int64_t polysemous_ht = DEFAULT_POLYSEMOUS_HT; int64_t max_codes = DEFAULT_MAX_CODES; - IVFPQCfg(const int64_t &dim, - const int64_t &k, - const int64_t &gpu_id, - const int64_t &nlist, - const int64_t &nprobe, - const int64_t &nbits, - const int64_t &m, - METRICTYPE type) - : nbits(nbits), m(m), IVFCfg(dim, k, gpu_id, nlist, nprobe, type) {} + IVFPQCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe, + const int64_t& nbits, const int64_t& m, METRICTYPE type) + : IVFCfg(dim, k, gpu_id, nlist, nprobe, type), m(m), nbits(nbits) { + } IVFPQCfg() = default; bool - CheckValid() override {}; + CheckValid() override { + return true; + }; }; using IVFPQConfig = std::shared_ptr; @@ -113,23 +107,22 @@ struct NSGCfg : public IVFCfg { int64_t out_degree = DEFAULT_OUT_DEGREE; int64_t candidate_pool_size = DEFAULT_CANDIDATE_SISE; - NSGCfg(const int64_t &dim, - const int64_t &k, - const int64_t &gpu_id, - const int64_t &nlist, - const int64_t &nprobe, - const int64_t &knng, - const int64_t &search_length, - const int64_t &out_degree, - const int64_t &candidate_size, + NSGCfg(const int64_t& dim, const int64_t& k, const int64_t& gpu_id, const int64_t& nlist, const int64_t& nprobe, + const int64_t& knng, const int64_t& search_length, const int64_t& out_degree, const int64_t& candidate_size, METRICTYPE type) - : knng(knng), search_length(search_length), out_degree(out_degree), candidate_pool_size(candidate_size), - IVFCfg(dim, k, gpu_id, nlist, nprobe, type) {} + : IVFCfg(dim, k, gpu_id, nlist, nprobe, type), + knng(knng), + search_length(search_length), + out_degree(out_degree), + candidate_pool_size(candidate_size) { + } NSGCfg() = default; bool - CheckValid() override {}; + CheckValid() override { + return true; + }; }; using NSGConfig = std::shared_ptr; @@ -137,6 +130,4 @@ struct KDTCfg : public Cfg { int64_t tptnubmber = -1; }; -} // knowhere -} // zilliz - +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp index 3d9de271b727c0be6bd9ee0d81dce43b14592df4..19bf070dbae735dd5142c823cf824f382f64c670 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.cpp @@ -15,16 +15,13 @@ // specific language governing permissions and limitations // under the License. - #include -#include "KDTParameterMgr.h" - +#include "knowhere/index/vector_index/helpers/KDTParameterMgr.h" -namespace zilliz { namespace knowhere { -const std::vector & +const std::vector& KDTParameterMgr::GetKDTParameters() { return kdt_parameters_; } @@ -55,5 +52,4 @@ KDTParameterMgr::KDTParameterMgr() { }; } -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.h b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.h index 4b43e642e2877774a356c3802090f6642e4d885e..fe90761e173194e87e5e28cdc921e8f93df48107 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/helpers/KDTParameterMgr.h @@ -15,32 +15,33 @@ // specific language governing permissions and limitations // under the License. - #pragma once +#include #include +#include #include - -namespace zilliz { namespace knowhere { using KDTParameter = std::pair; class KDTParameterMgr { public: - const std::vector & + const std::vector& GetKDTParameters(); public: - static KDTParameterMgr & + static KDTParameterMgr& GetInstance() { static KDTParameterMgr instance; return instance; } - KDTParameterMgr(const KDTParameterMgr &) = delete; - KDTParameterMgr &operator=(const KDTParameterMgr &) = delete; + KDTParameterMgr(const KDTParameterMgr&) = delete; + KDTParameterMgr& + operator=(const KDTParameterMgr&) = delete; + private: KDTParameterMgr(); @@ -48,5 +49,4 @@ class KDTParameterMgr { std::vector kdt_parameters_; }; -} // namespace knowhere -} // namespace zilliz +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.cpp index c2b5122d655278372ae379baf186b957b8fb4000..375664fbf818ee8d2c949111d52e75278a3c314d 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.cpp @@ -15,28 +15,27 @@ // specific language governing permissions and limitations // under the License. -#include +#include #include -#include +#include #include +#include #include -#include +#include -#include "NSG.h" #include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" #include "knowhere/common/Timer.h" -#include "NSGHelper.h" +#include "knowhere/index/vector_index/nsg/NSG.h" +#include "knowhere/index/vector_index/nsg/NSGHelper.h" // TODO: enable macro //#include - -namespace zilliz { namespace knowhere { namespace algo { - -NsgIndex::NsgIndex(const size_t &dimension, const size_t &n, MetricType metric) +NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, MetricType metric) : dimension(dimension), ntotal(n), metric_type(metric) { } @@ -45,16 +44,17 @@ NsgIndex::~NsgIndex() { delete[] ids_; } -//void NsgIndex::Build(size_t nb, const float *data, const BuildParam ¶meters) { +// void NsgIndex::Build(size_t nb, const float *data, const BuildParam ¶meters) { //} -void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, const BuildParams ¶meters) { +void +NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters) { TimeRecorder rc("NSG"); ntotal = nb; ori_data_ = new float[ntotal * dimension]; - ids_ = new long[ntotal]; - memcpy((void *) ori_data_, (void *) data, sizeof(float) * ntotal * dimension); - memcpy((void *) ids_, (void *) ids, sizeof(long) * ntotal); + ids_ = new int64_t[ntotal]; + memcpy((void*)ori_data_, (void*)data, sizeof(float) * ntotal * dimension); + memcpy((void*)ids_, (void*)ids, sizeof(int64_t) * ntotal); search_length = parameters.search_length; out_degree = parameters.out_degree; @@ -68,8 +68,8 @@ void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, con //>> Debug code ///// - //int count = 0; - //for (int i = 0; i < ntotal; ++i) { + // int count = 0; + // for (int i = 0; i < ntotal; ++i) { // count += nsg[i].size(); //} ///// @@ -80,17 +80,19 @@ void NsgIndex::Build_with_ids(size_t nb, const float *data, const long *ids, con //>> Debug code /// int total_degree = 0; - for (int i = 0; i < ntotal; ++i) { + for (size_t i = 0; i < ntotal; ++i) { total_degree += nsg[i].size(); } - std::cout << "graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024; - std::cout << "average degree: " << total_degree / ntotal; + + KNOWHERE_LOG_DEBUG << "Graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024 << "m"; + KNOWHERE_LOG_DEBUG << "Average degree: " << total_degree / ntotal; ///// is_trained = true; } -void NsgIndex::InitNavigationPoint() { +void +NsgIndex::InitNavigationPoint() { // calculate the center of vectors auto center = new float[dimension]; memset(center, 0, sizeof(float) * dimension); @@ -106,11 +108,12 @@ void NsgIndex::InitNavigationPoint() { // select navigation point std::vector resset, fullset; - navigation_point = rand() % ntotal; // random initialize navigating point + unsigned int seed = 100; + navigation_point = rand_r(&seed) % ntotal; // random initialize navigating point //>> Debug code ///// - //navigation_point = drand48(); + // navigation_point = drand48(); ///// GetNeighbors(center, resset, knng); @@ -118,22 +121,21 @@ void NsgIndex::InitNavigationPoint() { //>> Debug code ///// - //std::cout << "ep: " << navigation_point << std::endl; + // std::cout << "ep: " << navigation_point << std::endl; ///// //>> Debug code ///// - //float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension); - //assert(r1 == resset[0].distance); + // float r1 = calculate(center, ori_data_ + navigation_point * dimension, dimension); + // assert(r1 == resset[0].distance); ///// } // Specify Link -void NsgIndex::GetNeighbors(const float *query, - std::vector &resset, - std::vector &fullset, - boost::dynamic_bitset<> &has_calculated_dist) { - auto &graph = knng; +void +NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::vector& fullset, + boost::dynamic_bitset<>& has_calculated_dist) { + auto& graph = knng; size_t buffer_size = search_length; if (buffer_size > ntotal) { @@ -154,9 +156,12 @@ void NsgIndex::GetNeighbors(const float *query, has_calculated_dist[init_ids[i]] = true; ++count; } + + unsigned int seed = 100; while (count < buffer_size) { - node_t id = rand() % ntotal; - if (has_calculated_dist[id]) continue; // duplicate id + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) + continue; // duplicate id init_ids.push_back(id); ++count; has_calculated_dist[id] = true; @@ -170,7 +175,7 @@ void NsgIndex::GetNeighbors(const float *query, for (size_t i = 0; i < init_ids.size(); ++i) { node_t id = init_ids[i]; - if (id >= ntotal) { + if (id >= static_cast(ntotal)) { KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); continue; } @@ -182,9 +187,9 @@ void NsgIndex::GetNeighbors(const float *query, fullset.push_back(resset[i]); /////////////////////////////////////// } - std::sort(resset.begin(), resset.end()); // sort by distance + std::sort(resset.begin(), resset.end()); // sort by distance - //search nearest neighbor + // search nearest neighbor size_t cursor = 0; while (cursor < buffer_size) { size_t nearest_updated_pos = buffer_size; @@ -193,36 +198,42 @@ void NsgIndex::GetNeighbors(const float *query, resset[cursor].has_explored = true; node_t start_pos = resset[cursor].id; - auto &wait_for_search_node_vec = graph[start_pos]; + auto& wait_for_search_node_vec = graph[start_pos]; for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) { node_t id = wait_for_search_node_vec[i]; - if (has_calculated_dist[id]) continue; + if (has_calculated_dist[id]) + continue; has_calculated_dist[id] = true; - float - dist = calculate(query, ori_data_ + dimension * id, dimension); + float dist = calculate(query, ori_data_ + dimension * id, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); - if (dist >= resset[buffer_size - 1].distance) continue; + if (dist >= resset[buffer_size - 1].distance) + continue; - size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node - if (pos < nearest_updated_pos) nearest_updated_pos = pos; + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) + nearest_updated_pos = pos; - //assert(buffer_size + 1 >= resset.size()); - if (buffer_size + 1 < resset.size()) ++buffer_size; + // assert(buffer_size + 1 >= resset.size()); + if (buffer_size + 1 < resset.size()) + ++buffer_size; } } if (cursor >= nearest_updated_pos) { - cursor = nearest_updated_pos; // re-search from new pos - } else ++cursor; + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } } } } // FindUnconnectedNode -void NsgIndex::GetNeighbors(const float *query, std::vector &resset, std::vector &fullset) { - auto &graph = nsg; +void +NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::vector& fullset) { + auto& graph = nsg; size_t buffer_size = search_length; if (buffer_size > ntotal) { @@ -230,7 +241,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector &resset, s } std::vector init_ids; - boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; // TODO: ? + boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; // TODO: ? { /* @@ -244,9 +255,11 @@ void NsgIndex::GetNeighbors(const float *query, std::vector &resset, s has_calculated_dist[init_ids[i]] = true; ++count; } + unsigned int seed = 100; while (count < buffer_size) { - node_t id = rand() % ntotal; - if (has_calculated_dist[id]) continue; // duplicate id + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) + continue; // duplicate id init_ids.push_back(id); ++count; has_calculated_dist[id] = true; @@ -260,7 +273,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector &resset, s for (size_t i = 0; i < init_ids.size(); ++i) { node_t id = init_ids[i]; - if (id >= ntotal) { + if (id >= static_cast(ntotal)) { KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); continue; } @@ -268,7 +281,7 @@ void NsgIndex::GetNeighbors(const float *query, std::vector &resset, s float dist = calculate(ori_data_ + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } - std::sort(resset.begin(), resset.end()); // sort by distance + std::sort(resset.begin(), resset.end()); // sort by distance // search nearest neighbor size_t cursor = 0; @@ -279,38 +292,41 @@ void NsgIndex::GetNeighbors(const float *query, std::vector &resset, s resset[cursor].has_explored = true; node_t start_pos = resset[cursor].id; - auto &wait_for_search_node_vec = graph[start_pos]; + auto& wait_for_search_node_vec = graph[start_pos]; for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) { node_t id = wait_for_search_node_vec[i]; - if (has_calculated_dist[id]) continue; + if (has_calculated_dist[id]) + continue; has_calculated_dist[id] = true; - float - dist = calculate(ori_data_ + dimension * id, query, dimension); + float dist = calculate(ori_data_ + dimension * id, query, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); - if (dist >= resset[buffer_size - 1].distance) continue; + if (dist >= resset[buffer_size - 1].distance) + continue; - size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node - if (pos < nearest_updated_pos) nearest_updated_pos = pos; + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) + nearest_updated_pos = pos; - //assert(buffer_size + 1 >= resset.size()); - if (buffer_size + 1 < resset.size()) ++buffer_size; // trick + // assert(buffer_size + 1 >= resset.size()); + if (buffer_size + 1 < resset.size()) + ++buffer_size; // trick } } if (cursor >= nearest_updated_pos) { - cursor = nearest_updated_pos; // re-search from new pos - } else ++cursor; + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } } } } -void NsgIndex::GetNeighbors(const float *query, - std::vector &resset, - Graph &graph, - SearchParams *params) { - size_t &buffer_size = params ? params->search_length : search_length; +void +NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& graph, SearchParams* params) { + size_t& buffer_size = params ? params->search_length : search_length; if (buffer_size > ntotal) { // TODO: throw exception here. @@ -331,9 +347,11 @@ void NsgIndex::GetNeighbors(const float *query, has_calculated_dist[init_ids[i]] = true; ++count; } + unsigned int seed = 100; while (count < buffer_size) { - node_t id = rand() % ntotal; - if (has_calculated_dist[id]) continue; // duplicate id + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) + continue; // duplicate id init_ids.push_back(id); ++count; has_calculated_dist[id] = true; @@ -347,8 +365,8 @@ void NsgIndex::GetNeighbors(const float *query, for (size_t i = 0; i < init_ids.size(); ++i) { node_t id = init_ids[i]; - //assert(id < ntotal); - if (id >= ntotal) { + // assert(id < ntotal); + if (id >= static_cast(ntotal)) { KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); continue; } @@ -356,11 +374,11 @@ void NsgIndex::GetNeighbors(const float *query, float dist = calculate(ori_data_ + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } - std::sort(resset.begin(), resset.end()); // sort by distance + std::sort(resset.begin(), resset.end()); // sort by distance //>> Debug code ///// - //for (int j = 0; j < buffer_size; ++j) { + // for (int j = 0; j < buffer_size; ++j) { // std::cout << "resset_id: " << resset[j].id << ", resset_dist: " << resset[j].distance << std::endl; //} ///// @@ -374,41 +392,47 @@ void NsgIndex::GetNeighbors(const float *query, resset[cursor].has_explored = true; node_t start_pos = resset[cursor].id; - auto &wait_for_search_node_vec = graph[start_pos]; + auto& wait_for_search_node_vec = graph[start_pos]; for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) { node_t id = wait_for_search_node_vec[i]; - if (has_calculated_dist[id]) continue; + if (has_calculated_dist[id]) + continue; has_calculated_dist[id] = true; - float - dist = calculate(query, ori_data_ + dimension * id, dimension); + float dist = calculate(query, ori_data_ + dimension * id, dimension); - if (dist >= resset[buffer_size - 1].distance) continue; + if (dist >= resset[buffer_size - 1].distance) + continue; ///////////// difference from other GetNeighbors /////////////// Neighbor nn(id, dist, false); /////////////////////////////////////// - size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node - if (pos < nearest_updated_pos) nearest_updated_pos = pos; + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) + nearest_updated_pos = pos; //>> Debug code ///// - //std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << nearest_updated_pos << std::endl; + // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << + // nearest_updated_pos << std::endl; ///// - // trick: avoid search query search_length < init_ids.size() ... - if (buffer_size + 1 < resset.size()) ++buffer_size; + if (buffer_size + 1 < resset.size()) + ++buffer_size; } } if (cursor >= nearest_updated_pos) { - cursor = nearest_updated_pos; // re-search from new pos - } else ++cursor; + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } } } } -void NsgIndex::Link() { +void +NsgIndex::Link() { auto cut_graph_dist = new float[ntotal * out_degree]; nsg.resize(ntotal); @@ -416,7 +440,7 @@ void NsgIndex::Link() { { std::vector fullset; std::vector temp; - boost::dynamic_bitset<> flags{ntotal, 0}; // TODO: ? + boost::dynamic_bitset<> flags{ntotal, 0}; // TODO: ? #pragma omp for schedule(dynamic, 100) for (size_t n = 0; n < ntotal; ++n) { fullset.clear(); @@ -425,8 +449,8 @@ void NsgIndex::Link() { //>> Debug code ///// - //float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension); - //assert(r1 == temp[0].distance); + // float r1 = calculate(ori_data_ + n * dimension, ori_data_ + temp[0].id * dimension, dimension); + // assert(r1 == temp[0].distance); ///// SyncPrune(n, fullset, flags, cut_graph_dist); } @@ -434,7 +458,7 @@ void NsgIndex::Link() { //>> Debug code ///// - //auto bak_nsg = nsg; + // auto bak_nsg = nsg; ///// knng.clear(); @@ -450,8 +474,8 @@ void NsgIndex::Link() { //>> Debug code ///// - //int count = 0; - //for (int i = 0; i < ntotal; ++i) { + // int count = 0; + // for (int i = 0; i < ntotal; ++i) { // if (bak_nsg[i].size() != nsg[i].size()) { // //count += nsg[i].size() - bak_nsg[i].size(); // count += nsg[i].size(); @@ -459,21 +483,20 @@ void NsgIndex::Link() { //} ///// - for (int i = 0; i < ntotal; ++i) { + for (size_t i = 0; i < ntotal; ++i) { nsg[i].shrink_to_fit(); } } -void NsgIndex::SyncPrune(size_t n, - std::vector &pool, - boost::dynamic_bitset<> &has_calculated, - float *cut_graph_dist) { +void +NsgIndex::SyncPrune(size_t n, std::vector& pool, boost::dynamic_bitset<>& has_calculated, + float* cut_graph_dist) { // avoid lose nearest neighbor in knng for (size_t i = 0; i < knng[n].size(); ++i) { auto id = knng[n][i]; - if (has_calculated[id]) continue; - float dist = calculate(ori_data_ + dimension * n, - ori_data_ + dimension * id, dimension); + if (has_calculated[id]) + continue; + float dist = calculate(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension); pool.emplace_back(Neighbor(id, dist, true)); } @@ -481,14 +504,16 @@ void NsgIndex::SyncPrune(size_t n, unsigned cursor = 0; std::sort(pool.begin(), pool.end()); std::vector result; - if (pool[cursor].id == n) cursor++; - result.push_back(pool[cursor]); // init result with nearest neighbor + if (pool[cursor].id == static_cast(n)) { + cursor++; + } + result.push_back(pool[cursor]); // init result with nearest neighbor SelectEdge(cursor, pool, result, true); // filling the cut_graph - auto &des_id_pool = nsg[n]; - float *des_dist_pool = cut_graph_dist + n * out_degree; + auto& des_id_pool = nsg[n]; + float* des_dist_pool = cut_graph_dist + n * out_degree; for (size_t i = 0; i < result.size(); ++i) { des_id_pool.push_back(result[i].id); des_dist_pool[i] = result[i].distance; @@ -500,24 +525,27 @@ void NsgIndex::SyncPrune(size_t n, } //>> Optimize: remove read-lock -void NsgIndex::InterInsert(unsigned n, std::vector &mutex_vec, float *cut_graph_dist) { - auto ¤t = n; +void +NsgIndex::InterInsert(unsigned n, std::vector& mutex_vec, float* cut_graph_dist) { + auto& current = n; - auto &neighbor_id_pool = nsg[current]; - float *neighbor_dist_pool = cut_graph_dist + current * out_degree; + auto& neighbor_id_pool = nsg[current]; + float* neighbor_dist_pool = cut_graph_dist + current * out_degree; for (size_t i = 0; i < out_degree; ++i) { - if (neighbor_dist_pool[i] == -1) break; + if (neighbor_dist_pool[i] == -1) + break; - size_t current_neighbor = neighbor_id_pool[i]; // center's neighbor id - auto &nsn_id_pool = nsg[current_neighbor]; // nsn => neighbor's neighbor - float *nsn_dist_pool = cut_graph_dist + current_neighbor * out_degree; + size_t current_neighbor = neighbor_id_pool[i]; // center's neighbor id + auto& nsn_id_pool = nsg[current_neighbor]; // nsn => neighbor's neighbor + float* nsn_dist_pool = cut_graph_dist + current_neighbor * out_degree; - std::vector wait_for_link_pool; // maintain candidate neighbor of the current neighbor. + std::vector wait_for_link_pool; // maintain candidate neighbor of the current neighbor. int duplicate = false; { LockGuard lk(mutex_vec[current_neighbor]); - for (int j = 0; j < out_degree; ++j) { - if (nsn_dist_pool[j] == -1) break; + for (size_t j = 0; j < out_degree; ++j) { + if (nsn_dist_pool[j] == -1) + break; // 保证至少有一条边能连回来 if (n == nsn_id_pool[j]) { @@ -529,7 +557,8 @@ void NsgIndex::InterInsert(unsigned n, std::vector &mutex_vec, float wait_for_link_pool.push_back(nsn); } } - if (duplicate) continue; + if (duplicate) + continue; // original: (neighbor) <------- (current) // after: (neighbor) -------> (current) @@ -549,31 +578,29 @@ void NsgIndex::InterInsert(unsigned n, std::vector &mutex_vec, float { LockGuard lk(mutex_vec[current_neighbor]); - for (int j = 0; j < result.size(); ++j) { + for (size_t j = 0; j < result.size(); ++j) { nsn_id_pool[j] = result[j].id; nsn_dist_pool[j] = result[j].distance; } } } else { LockGuard lk(mutex_vec[current_neighbor]); - for (int j = 0; j < out_degree; ++j) { + for (size_t j = 0; j < out_degree; ++j) { if (nsn_dist_pool[j] == -1) { nsn_id_pool.push_back(current_as_neighbor.id); nsn_dist_pool[j] = current_as_neighbor.distance; - if (j + 1 < out_degree) nsn_dist_pool[j + 1] = -1; + if (j + 1 < out_degree) + nsn_dist_pool[j + 1] = -1; break; } } } - } } -void NsgIndex::SelectEdge(unsigned &cursor, - std::vector &sort_pool, - std::vector &result, - bool limit) { - auto &pool = sort_pool; +void +NsgIndex::SelectEdge(unsigned& cursor, std::vector& sort_pool, std::vector& result, bool limit) { + auto& pool = sort_pool; /* * edge selection @@ -583,55 +610,59 @@ void NsgIndex::SelectEdge(unsigned &cursor, */ size_t search_deepth = limit ? candidate_pool_size : pool.size(); while (result.size() < out_degree && cursor < search_deepth && (++cursor) < pool.size()) { - auto &p = pool[cursor]; + auto& p = pool[cursor]; bool should_link = true; for (size_t t = 0; t < result.size(); ++t) { - float dist = calculate(ori_data_ + dimension * result[t].id, - ori_data_ + dimension * p.id, dimension); + float dist = calculate(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension); if (dist < p.distance) { should_link = false; break; } } - if (should_link) result.push_back(p); + if (should_link) + result.push_back(p); } } -void NsgIndex::CheckConnectivity() { +void +NsgIndex::CheckConnectivity() { auto root = navigation_point; boost::dynamic_bitset<> has_linked{ntotal, 0}; int64_t linked_count = 0; - while (linked_count < ntotal) { + while (linked_count < static_cast(ntotal)) { DFS(root, has_linked, linked_count); - if (linked_count >= ntotal) break; + if (linked_count >= static_cast(ntotal)) { + break; + } FindUnconnectedNode(has_linked, root); } } -void NsgIndex::DFS(size_t root, boost::dynamic_bitset<> &has_linked, int64_t &linked_count) { +void +NsgIndex::DFS(size_t root, boost::dynamic_bitset<>& has_linked, int64_t& linked_count) { size_t start = root; std::stack s; s.push(root); if (!has_linked[root]) { - linked_count++; // not link - has_linked[root] = true; // link start... + linked_count++; // not link + has_linked[root] = true; // link start... } while (!s.empty()) { size_t next = ntotal + 1; for (unsigned i = 0; i < nsg[start].size(); i++) { - if (has_linked[nsg[start][i]] == false) // if not link - { + if (has_linked[nsg[start][i]] == false) { // if not link next = nsg[start][i]; break; } } if (next == (ntotal + 1)) { s.pop(); - if (s.empty()) break; + if (s.empty()) + break; start = s.top(); continue; } @@ -642,17 +673,19 @@ void NsgIndex::DFS(size_t root, boost::dynamic_bitset<> &has_linked, int64_t &li } } -void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t &root) { +void +NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root) { // find any of unlinked-node size_t id = ntotal; - for (size_t i = 0; i < ntotal; i++) { // find not link + for (size_t i = 0; i < ntotal; i++) { // find not link if (has_linked[i] == false) { id = i; break; } } - if (id == ntotal) return; // No Unlinked Node + if (id == ntotal) + return; // No Unlinked Node // search unlinked-node's neighbor std::vector tmp, pool; @@ -660,7 +693,7 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t std::sort(pool.begin(), pool.end()); size_t found = 0; - for (size_t i = 0; i < pool.size(); i++) { // find nearest neighbor and add unlinked-node as its neighbor + for (size_t i = 0; i < pool.size(); i++) { // find nearest neighbor and add unlinked-node as its neighbor if (has_linked[pool[i].id]) { root = pool[i].id; found = 1; @@ -668,8 +701,9 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t } } if (found == 0) { - while (true) { // random a linked-node and add unlinked-node as its neighbor - size_t rid = rand() % ntotal; + unsigned int seed = 100; + while (true) { // random a linked-node and add unlinked-node as its neighbor + size_t rid = rand_r(&seed) % ntotal; if (has_linked[rid]) { root = rid; break; @@ -679,23 +713,18 @@ void NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<> &has_linked, int64_t nsg[root].push_back(id); } - -void NsgIndex::Search(const float *query, - const unsigned &nq, - const unsigned &dim, - const unsigned &k, - float *dist, - long *ids, - SearchParams ¶ms) { +void +NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, + int64_t* ids, SearchParams& params) { std::vector> resset(nq); TimeRecorder rc("search"); if (nq == 1) { GetNeighbors(query, resset[0], nsg, ¶ms); - } else{ - //#pragma omp parallel for schedule(dynamic, 50) - #pragma omp parallel for - for (int i = 0; i < nq; ++i) { + } else { +//#pragma omp parallel for schedule(dynamic, 50) +#pragma omp parallel for + for (unsigned int i = 0; i < nq; ++i) { // TODO(linxj): when to use openmp auto single_query = query + i * dim; GetNeighbors(single_query, resset[i], nsg, ¶ms); @@ -703,9 +732,9 @@ void NsgIndex::Search(const float *query, } rc.ElapseFromBegin("cost"); - for (int i = 0; i < nq; ++i) { - for (int j = 0; j < k; ++j) { - //ids[i * k + j] = resset[i][j].id; + for (unsigned int i = 0; i < nq; ++i) { + for (unsigned int j = 0; j < k; ++j) { + // ids[i * k + j] = resset[i][j].id; // Fix(linxj): bug, reset[i][j] out of range ids[i * k + j] = ids_[resset[i][j].id]; @@ -714,27 +743,28 @@ void NsgIndex::Search(const float *query, } //>> Debug: test single insert - //int x_0 = resset[0].size(); - //for (int l = 0; l < resset[0].size(); ++l) { + // int x_0 = resset[0].size(); + // for (int l = 0; l < resset[0].size(); ++l) { // resset[0].pop_back(); //} - //resset.clear(); + // resset.clear(); - //ProfilerStart("xx.prof"); - //std::vector resset; - //GetNeighbors(query, resset, nsg, ¶ms); - //for (int i = 0; i < k; ++i) { + // ProfilerStart("xx.prof"); + // std::vector resset; + // GetNeighbors(query, resset, nsg, ¶ms); + // for (int i = 0; i < k; ++i) { // ids[i] = resset[i].id; - //dist[i] = resset[i].distance; + // dist[i] = resset[i].distance; //} - //ProfilerStop(); + // ProfilerStop(); } -void NsgIndex::SetKnnGraph(Graph &g) { +void +NsgIndex::SetKnnGraph(Graph& g) { knng = std::move(g); } -//void NsgIndex::GetKnnGraphFromFile() { +// void NsgIndex::GetKnnGraphFromFile() { // //std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift.1M.50NN.graph"; // std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift.50NN.graph"; // @@ -759,6 +789,5 @@ void NsgIndex::SetKnnGraph(Graph &g) { // in.close(); //} -} -} -} +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.h b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.h index 5291c17280a387f8540c6e752d3b178f90e96c88..160c076e45e717948f55e24352b82c69ab5b4b14 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSG.h @@ -15,22 +15,18 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include -#include #include +#include #include #include "Neighbor.h" - -namespace zilliz { namespace knowhere { namespace algo { - using node_t = int64_t; enum class MetricType { @@ -53,15 +49,15 @@ using Graph = std::vector>; class NsgIndex { public: size_t dimension; - size_t ntotal; // totabl nb of indexed vectors - MetricType metric_type; // L2 | IP + size_t ntotal; // totabl nb of indexed vectors + MetricType metric_type; // L2 | IP - float *ori_data_; - long *ids_; // TODO: support different type - Graph nsg; // final graph - Graph knng; // reset after build + float* ori_data_; + int64_t* ids_; // TODO: support different type + Graph nsg; // final graph + Graph knng; // reset after build - node_t navigation_point; // offset of node in origin data + node_t navigation_point; // offset of node in origin data bool is_trained = false; @@ -69,91 +65,80 @@ class NsgIndex { * build and search parameter */ size_t search_length; - size_t candidate_pool_size; // search deepth in fullset + size_t candidate_pool_size; // search deepth in fullset size_t out_degree; public: - explicit NsgIndex(const size_t &dimension, - const size_t &n, - MetricType metric = MetricType::METRIC_L2); + explicit NsgIndex(const size_t& dimension, const size_t& n, MetricType metric = MetricType::METRIC_L2); NsgIndex() = default; virtual ~NsgIndex(); - void SetKnnGraph(Graph &knng); + void + SetKnnGraph(Graph& knng); - virtual void Build_with_ids(size_t nb, - const float *data, - const long *ids, - const BuildParams ¶meters); + virtual void + Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters); - void Search(const float *query, - const unsigned &nq, - const unsigned &dim, - const unsigned &k, - float *dist, - long *ids, - SearchParams ¶ms); + void + Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, int64_t* ids, + SearchParams& params); // Not support yet. - //virtual void Add() = 0; - //virtual void Add_with_ids() = 0; - //virtual void Delete() = 0; - //virtual void Delete_with_ids() = 0; - //virtual void Rebuild(size_t nb, + // virtual void Add() = 0; + // virtual void Add_with_ids() = 0; + // virtual void Delete() = 0; + // virtual void Delete_with_ids() = 0; + // virtual void Rebuild(size_t nb, // const float *data, - // const long *ids, + // const int64_t *ids, // const Parameters ¶meters) = 0; - //virtual void Build(size_t nb, + // virtual void Build(size_t nb, // const float *data, // const BuildParam ¶meters); protected: - virtual void InitNavigationPoint(); + virtual void + InitNavigationPoint(); // link specify - void GetNeighbors(const float *query, - std::vector &resset, - std::vector &fullset, - boost::dynamic_bitset<> &has_calculated_dist); + void + GetNeighbors(const float* query, std::vector& resset, std::vector& fullset, + boost::dynamic_bitset<>& has_calculated_dist); // FindUnconnectedNode - void GetNeighbors(const float *query, - std::vector &resset, - std::vector &fullset); + void + GetNeighbors(const float* query, std::vector& resset, std::vector& fullset); // search and navigation-point - void GetNeighbors(const float *query, - std::vector &resset, - Graph &graph, - SearchParams *param = nullptr); + void + GetNeighbors(const float* query, std::vector& resset, Graph& graph, SearchParams* param = nullptr); - void Link(); + void + Link(); - void SyncPrune(size_t q, - std::vector &pool, - boost::dynamic_bitset<> &has_calculated, - float *cut_graph_dist - ); + void + SyncPrune(size_t q, std::vector& pool, boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist); - void SelectEdge(unsigned &cursor, - std::vector &sort_pool, - std::vector &result, - bool limit = false); + void + SelectEdge(unsigned& cursor, std::vector& sort_pool, std::vector& result, bool limit = false); - void InterInsert(unsigned n, std::vector &mutex_vec, float *dist); + void + InterInsert(unsigned n, std::vector& mutex_vec, float* dist); - void CheckConnectivity(); + void + CheckConnectivity(); - void DFS(size_t root, boost::dynamic_bitset<> &flags, int64_t &count); + void + DFS(size_t root, boost::dynamic_bitset<>& flags, int64_t& count); - void FindUnconnectedNode(boost::dynamic_bitset<> &flags, int64_t &root); + void + FindUnconnectedNode(boost::dynamic_bitset<>& flags, int64_t& root); - //private: - // void GetKnnGraphFromFile(); + // private: + // void GetKnnGraphFromFile(); }; -} -} -} +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp index eb87f97f3edcdbfa7a0b6557687e7ca704d2b370..05e8d18787ed09c608f71d893319659959871b99 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.cpp @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. - #include #include -#include "NSGHelper.h" - +#include "knowhere/index/vector_index/nsg/NSGHelper.h" -namespace zilliz { namespace knowhere { namespace algo { // TODO: impl search && insert && return insert pos. why not just find and swap? -int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) { +int +InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) { //>> Fix: Add assert - for (int i = 0; i < K; ++i) { + for (unsigned int i = 0; i < K; ++i) { assert(addr[i].id != nn.id); } @@ -37,7 +35,7 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) { int left = 0, right = K - 1; if (addr[left].distance > nn.distance) { //>> Fix: memmove overflow, dump when vector deconstruct - memmove((char *) &addr[left + 1], &addr[left], (K - 1) * sizeof(Neighbor)); + memmove((char*)&addr[left + 1], &addr[left], (K - 1) * sizeof(Neighbor)); addr[left] = nn; return left; } @@ -52,10 +50,10 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) { else left = mid; } - //check equal ID + // check equal ID while (left > 0) { - if (addr[left].distance < nn.distance) // pos is right + if (addr[left].distance < nn.distance) // pos is right break; if (addr[left].id == nn.id) return K + 1; @@ -65,24 +63,25 @@ int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn) { return K + 1; //>> Fix: memmove overflow, dump when vector deconstruct - memmove((char *) &addr[right + 1], &addr[right], (K - 1 - right) * sizeof(Neighbor)); + memmove((char*)&addr[right + 1], &addr[right], (K - 1 - right) * sizeof(Neighbor)); addr[right] = nn; return right; } // TODO: support L2 / IP -float calculate(const float *a, const float *b, unsigned size) { +float +calculate(const float* a, const float* b, unsigned size) { float result = 0; #ifdef __GNUC__ #ifdef __AVX__ #define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm256_loadu_ps(addr1);\ - tmp2 = _mm256_loadu_ps(addr2);\ - tmp1 = _mm256_sub_ps(tmp1, tmp2); \ - tmp1 = _mm256_mul_ps(tmp1, tmp1); \ - dest = _mm256_add_ps(dest, tmp1); + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_sub_ps(tmp1, tmp2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp1); \ + dest = _mm256_add_ps(dest, tmp1); __m256 sum; __m256 l0, l1; @@ -90,14 +89,16 @@ float calculate(const float *a, const float *b, unsigned size) { unsigned D = (size + 7) & ~7U; unsigned DR = D % 16; unsigned DD = D - DR; - const float *l = a; - const float *r = b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[8] __attribute__ ((aligned (32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; sum = _mm256_loadu_ps(unpack); - if (DR) { AVX_L2SQR(e_l, e_r, sum, l0, r0); } + if (DR) { + AVX_L2SQR(e_l, e_r, sum, l0, r0); + } for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { AVX_L2SQR(l, r, sum, l0, r0); @@ -109,11 +110,11 @@ float calculate(const float *a, const float *b, unsigned size) { #else #ifdef __SSE2__ #define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ - tmp1 = _mm_load_ps(addr1);\ - tmp2 = _mm_load_ps(addr2);\ - tmp1 = _mm_sub_ps(tmp1, tmp2); \ - tmp1 = _mm_mul_ps(tmp1, tmp1); \ - dest = _mm_add_ps(dest, tmp1); + tmp1 = _mm_load_ps(addr1); \ + tmp2 = _mm_load_ps(addr2); \ + tmp1 = _mm_sub_ps(tmp1, tmp2); \ + tmp1 = _mm_mul_ps(tmp1, tmp1); \ + dest = _mm_add_ps(dest, tmp1); __m128 sum; __m128 l0, l1, l2, l3; @@ -121,18 +122,22 @@ float calculate(const float *a, const float *b, unsigned size) { unsigned D = (size + 3) & ~3U; unsigned DR = D % 16; unsigned DD = D - DR; - const float *l = a; - const float *r = b; - const float *e_l = l + DD; - const float *e_r = r + DD; - float unpack[4] __attribute__ ((aligned (16))) = {0, 0, 0, 0}; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; sum = _mm_load_ps(unpack); switch (DR) { - case 12:SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2); - case 8:SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1); - case 4:SSE_L2SQR(e_l, e_r, sum, l0, r0); - default:break; + case 12: + SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_L2SQR(e_l, e_r, sum, l0, r0); + default: + break; } for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { SSE_L2SQR(l, r, sum, l0, r0); @@ -143,28 +148,28 @@ float calculate(const float *a, const float *b, unsigned size) { _mm_storeu_ps(unpack, sum); result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; -//nomal distance +// nomal distance #else float diff0, diff1, diff2, diff3; - const float* last = a + size; - const float* unroll_group = last - 3; - - /* Process 4 items with each loop for efficiency. */ - while (a < unroll_group) { - diff0 = a[0] - b[0]; - diff1 = a[1] - b[1]; - diff2 = a[2] - b[2]; - diff3 = a[3] - b[3]; - result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; - a += 4; - b += 4; - } - /* Process last 0-3 pixels. Not needed for standard vector lengths. */ - while (a < last) { - diff0 = *a++ - *b++; - result += diff0 * diff0; - } + const float* last = a + size; + const float* unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) { + diff0 = a[0] - b[0]; + diff1 = a[1] - b[1]; + diff2 = a[2] - b[2]; + diff3 = a[3] - b[3]; + result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) { + diff0 = *a++ - *b++; + result += diff0 * diff0; + } #endif #endif #endif @@ -172,7 +177,5 @@ float calculate(const float *a, const float *b, unsigned size) { return result; } - -} -} -} \ No newline at end of file +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h index f5c13194c403033acd0ac9e0be6917516f41f07c..5007cf019c016f6040b779e65deb7de461d26654 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGHelper.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include @@ -26,14 +25,13 @@ #include "NSG.h" #include "knowhere/common/Config.h" - -namespace zilliz { namespace knowhere { namespace algo { -extern int InsertIntoPool(Neighbor *addr, unsigned K, Neighbor nn); -extern float calculate(const float *a, const float *b, unsigned size); +extern int +InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn); +extern float +calculate(const float* a, const float* b, unsigned size); -} -} -} +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.cpp b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.cpp index bcdde8052c351e42c32e72feaed24711e8d0dbee..cac3b5864f22675a3e7e695adc40f2e9afc907d3 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.cpp +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.cpp @@ -15,31 +15,30 @@ // specific language governing permissions and limitations // under the License. - #include -#include "NSGIO.h" - +#include "knowhere/index/vector_index/nsg/NSGIO.h" -namespace zilliz { namespace knowhere { namespace algo { -void write_index(NsgIndex *index, MemoryIOWriter &writer) { +void +write_index(NsgIndex* index, MemoryIOWriter& writer) { writer(&index->ntotal, sizeof(index->ntotal), 1); writer(&index->dimension, sizeof(index->dimension), 1); writer(&index->navigation_point, sizeof(index->navigation_point), 1); writer(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); - writer(index->ids_, sizeof(long) * index->ntotal, 1); + writer(index->ids_, sizeof(int64_t) * index->ntotal, 1); for (unsigned i = 0; i < index->ntotal; ++i) { - auto neighbor_num = (node_t) index->nsg[i].size(); + auto neighbor_num = (node_t)index->nsg[i].size(); writer(&neighbor_num, sizeof(node_t), 1); writer(index->nsg[i].data(), neighbor_num * sizeof(node_t), 1); } } -NsgIndex *read_index(MemoryIOReader &reader) { +NsgIndex* +read_index(MemoryIOReader& reader) { size_t ntotal; size_t dimension; reader(&ntotal, sizeof(size_t), 1); @@ -48,9 +47,9 @@ NsgIndex *read_index(MemoryIOReader &reader) { reader(&index->navigation_point, sizeof(index->navigation_point), 1); index->ori_data_ = new float[index->ntotal * index->dimension]; - index->ids_ = new long[index->ntotal]; + index->ids_ = new int64_t[index->ntotal]; reader(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); - reader(index->ids_, sizeof(long) * index->ntotal, 1); + reader(index->ids_, sizeof(int64_t) * index->ntotal, 1); index->nsg.reserve(index->ntotal); index->nsg.resize(index->ntotal); @@ -66,6 +65,5 @@ NsgIndex *read_index(MemoryIOReader &reader) { return index; } -} -} -} +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.h b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.h index 3d6786c6c236278764a8ed5b418180f038a80275..12913b69df754824a60f2c08b7c6745bde0a65ef 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/NSGIO.h @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "knowhere/index/vector_index/helpers/FaissIO.h" #include "NSG.h" #include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" - -namespace zilliz { namespace knowhere { namespace algo { -extern void write_index(NsgIndex* index, MemoryIOWriter& writer); -extern NsgIndex* read_index(MemoryIOReader& reader); +extern void +write_index(NsgIndex* index, MemoryIOWriter& writer); +extern NsgIndex* +read_index(MemoryIOReader& reader); -} -} -} +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/Neighbor.h b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/Neighbor.h index 9aceb62692ccd875925a5844d233128eb790b66c..c3a314164c18797daacd7179bd6f40ef347e8b30 100644 --- a/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/Neighbor.h +++ b/cpp/src/core/knowhere/knowhere/index/vector_index/nsg/Neighbor.h @@ -15,13 +15,10 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include - -namespace zilliz { namespace knowhere { namespace algo { @@ -29,21 +26,25 @@ using node_t = int64_t; // TODO: search use simple neighbor struct Neighbor { - node_t id; // offset of node in origin data + node_t id; // offset of node in origin data float distance; bool has_explored; Neighbor() = default; - explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) {} - explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) {} + explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) { + } + + explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) { + } - inline bool operator<(const Neighbor &other) const { + inline bool + operator<(const Neighbor& other) const { return distance < other.distance; } }; -//struct SimpleNeighbor { +// struct SimpleNeighbor { // node_t id; // offset of node in origin data // float distance; // @@ -57,7 +58,5 @@ struct Neighbor { typedef std::lock_guard LockGuard; - -} -} -} \ No newline at end of file +} // namespace algo +} // namespace knowhere diff --git a/cpp/src/core/test/SPTAG.cpp b/cpp/src/core/test/SPTAG.cpp deleted file mode 100644 index 3dfea0f088428481e142b3cc7408a48b78ad480d..0000000000000000000000000000000000000000 --- a/cpp/src/core/test/SPTAG.cpp +++ /dev/null @@ -1,36 +0,0 @@ - -#include -#include -#include -#include "SPTAG/AnnService/inc/Core/Common.h" -#include "SPTAG/AnnService/inc/Core/VectorIndex.h" - - -int -main(int argc, char *argv[]) { - using namespace SPTAG; - const int d = 128; - const int n = 100; - - auto p_data = new float[n * d]; - - auto index = VectorIndex::CreateInstance(IndexAlgoType::KDT, VectorValueType::Float); - - std::random_device rd; - std::mt19937 mt(rd()); - std::uniform_real_distribution dist(1.0, 2.0); - - for (auto i = 0; i < n; i++) { - for (auto j = 0; j < d; j++) { - p_data[i * d + j] = dist(mt) - 1; - } - } - std::cout << "generate random n * d finished."; - ByteArray data((uint8_t *) p_data, n * d * sizeof(float), true); - - auto vectorset = std::make_shared(data, VectorValueType::Float, d, n); - index->BuildIndex(vectorset, nullptr); - - std::cout << index->GetFeatureDim(); -} - diff --git a/cpp/src/core/test/kdtree.cpp b/cpp/src/core/test/kdtree.cpp deleted file mode 100644 index 0b63e657e1f465a1dffef4e76afb3d30e6b6367f..0000000000000000000000000000000000000000 --- a/cpp/src/core/test/kdtree.cpp +++ /dev/null @@ -1,134 +0,0 @@ - -#include -#include -#include "knowhere/index/vector_index/cpu_kdt_rng.h" -#include "knowhere/index/vector_index/definitions.h" -#include "knowhere/adapter/sptag.h" -#include "knowhere/adapter/structure.h" - - -using namespace zilliz::knowhere; - -DatasetPtr -generate_dataset(int64_t n, int64_t d, int64_t base) { - auto elems = n * d; - auto p_data = (float *) malloc(elems * sizeof(float)); - auto p_id = (int64_t *) malloc(elems * sizeof(int64_t)); - assert(p_data != nullptr && p_id != nullptr); - - for (auto i = 0; i < n; ++i) { - for (auto j = 0; j < d; ++j) { - p_data[i * d + j] = float(base + i); - } - p_id[i] = i; - } - - std::vector shape{n, d}; - auto tensor = ConstructFloatTensorSmart((uint8_t *) p_data, elems * sizeof(float), shape); - std::vector tensors{tensor}; - std::vector tensor_fields{ConstructFloatField("data")}; - auto tensor_schema = std::make_shared(tensor_fields); - - auto id_array = ConstructInt64ArraySmart((uint8_t *) p_id, n * sizeof(int64_t)); - std::vector arrays{id_array}; - std::vector array_fields{ConstructInt64Field("id")}; - auto array_schema = std::make_shared(tensor_fields); - - auto dataset = std::make_shared(std::move(arrays), array_schema, - std::move(tensors), tensor_schema); - - return dataset; -} - -DatasetPtr -generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) { - size_t size = sizeof(float) * n * d; - auto v = (float *) malloc(size); - // TODO: check malloc - for (auto i = 0; i < n; ++i) { - for (auto j = 0; j < d; ++j) { - v[i * d + j] = float(base + i); - } - } - - std::vector data; - auto buffer = MakeMutableBufferSmart((uint8_t *) v, size); - std::vector shape{n, d}; - auto float_type = std::make_shared(); - auto tensor = std::make_shared(float_type, buffer, shape); - data.push_back(tensor); - - Config meta; - meta[META_ROWS] = int64_t (n); - meta[META_DIM] = int64_t (d); - meta[META_K] = int64_t (k); - - auto type = std::make_shared(); - auto field = std::make_shared("data", type); - std::vector fields{field}; - auto schema = std::make_shared(fields); - - return std::make_shared(data, schema); -} - - -int -main(int argc, char *argv[]) { - auto kdt_index = std::make_shared(); - - const auto d = 10; - const auto k = 3; - const auto nquery = 10; - - // ID [0, 99] - auto train = generate_dataset(100, d, 0); - // ID [100] - auto base = generate_dataset(1, d, 0); - auto queries = generate_queries(nquery, d, k, 0); - - // Build Preprocessor - auto preprocessor = kdt_index->BuildPreprocessor(train, Config()); - - // Set Preprocessor - kdt_index->set_preprocessor(preprocessor); - - Config train_config; - train_config["TPTNumber"] = "64"; - // Train - kdt_index->Train(train, train_config); - - // Add - kdt_index->Add(base, Config()); - - auto binary = kdt_index->Serialize(); - auto new_index = std::make_shared(); - new_index->Load(binary); -// auto new_index = kdt_index; - - Config search_config; - search_config[META_K] = int64_t (k); - - // Search - auto result = new_index->Search(queries, search_config); - - // Print Result - { - auto ids = result->array()[0]; - auto dists = result->array()[1]; - - std::stringstream ss_id; - std::stringstream ss_dist; - for (auto i = 0; i < nquery; i++) { - for (auto j = 0; j < k; ++j) { - ss_id << *ids->data()->GetValues(1, i * k + j) << " "; - ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; - } - ss_id << std::endl; - ss_dist << std::endl; - } - std::cout << "id\n" << ss_id.str() << std::endl; - std::cout << "dist\n" << ss_dist.str() << std::endl; - } -} - - diff --git a/cpp/src/core/test/test_nsg/CMakeLists.txt b/cpp/src/core/test/test_nsg/CMakeLists.txt deleted file mode 100644 index 17b62cd40ddd4de8ea09d345be840e4d7c447a47..0000000000000000000000000000000000000000 --- a/cpp/src/core/test/test_nsg/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -############################## -include_directories(/usr/local/include/gperftools) -link_directories(/usr/local/lib) - -add_definitions(-std=c++11 -O3 -lboost -march=native -Wall -DINFO) - -find_package(OpenMP) -if (OPENMP_FOUND) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") -else () - message(FATAL_ERROR "no OpenMP supprot") -endif () -message(${OpenMP_CXX_FLAGS}) - -include_directories(${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg) - -aux_source_directory(${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg nsg_src) - -set(interface_src - ${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/ivf.cpp - ${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/gpu_ivf.cpp - ${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/cloner.cpp - ${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/idmap.cpp - ${CORE_SOURCE_DIR}/src/knowhere/index/vector_index/nsg_index.cpp - ${CORE_SOURCE_DIR}/src/knowhere/adapter/structure.cpp - ${CORE_SOURCE_DIR}/src/knowhere/common/exception.cpp - ${CORE_SOURCE_DIR}/src/knowhere/common/timer.cpp - ../utils.cpp - ) - -if(NOT TARGET test_nsg) - add_executable(test_nsg - test_nsg.cpp - ${interface_src} - ${nsg_src} - ${util_srcs} - ) -endif() - -target_link_libraries(test_nsg ${depend_libs} ${unittest_libs} ${basic_libs}) -############################## - -install(TARGETS test_nsg DESTINATION unittest) \ No newline at end of file diff --git a/cpp/src/core/test/test_nsg/test_nsg.cpp b/cpp/src/core/test/test_nsg/test_nsg.cpp deleted file mode 100644 index b1b1edef46efc639c921332a6643bb94f226c68f..0000000000000000000000000000000000000000 --- a/cpp/src/core/test/test_nsg/test_nsg.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - -#include -#include - -#include "knowhere/common/exception.h" -#include "knowhere/index/vector_index/gpu_ivf.h" -#include "knowhere/index/vector_index/nsg_index.h" -#include "knowhere/index/vector_index/nsg/nsg_io.h" - -#include "../utils.h" - - -using namespace zilliz::knowhere; -using ::testing::TestWithParam; -using ::testing::Values; -using ::testing::Combine; - -constexpr int64_t DEVICE_ID = 0; - -class NSGInterfaceTest : public DataGen, public TestWithParam<::std::tuple> { - protected: - void SetUp() override { - //Init_with_default(); - FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024*1024*200, 1024*1024*600, 2); - Generate(256, 10000, 1); - index_ = std::make_shared(); - std::tie(train_cfg, search_cfg) = GetParam(); - } - - void TearDown() override { - FaissGpuResourceMgr::GetInstance().Free(); - } - - protected: - std::shared_ptr index_; - Config train_cfg; - Config search_cfg; -}; - -INSTANTIATE_TEST_CASE_P(NSGparameters, NSGInterfaceTest, - Values(std::make_tuple( - // search length > out_degree - Config::object{{"nlist", 128}, {"nprobe", 50}, {"knng", 100}, {"metric_type", "L2"}, - {"search_length", 60}, {"out_degree", 70}, {"candidate_pool_size", 500}}, - Config::object{{"k", 20}, {"search_length", 30}})) -); - -void AssertAnns(const DatasetPtr &result, - const int &nq, - const int &k) { - auto ids = result->array()[0]; - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); - } -} - -TEST_P(NSGInterfaceTest, basic_test) { - assert(!xb.empty()); - - auto model = index_->Train(base_dataset, train_cfg); - auto result = index_->Search(query_dataset, search_cfg); - AssertAnns(result, nq, k); - - auto binaryset = index_->Serialize(); - auto new_index = std::make_shared(); - new_index->Load(binaryset); - auto new_result = new_index->Search(query_dataset, Config::object{{"k", k}}); - AssertAnns(result, nq, k); - - ASSERT_EQ(index_->Count(), nb); - ASSERT_EQ(index_->Dimension(), dim); - ASSERT_THROW({index_->Clone();}, zilliz::knowhere::KnowhereException); - ASSERT_NO_THROW({ - index_->Add(base_dataset, Config()); - index_->Seal(); - }); - - { - //std::cout << "k = 1" << std::endl; - //new_index->Search(GenQuery(1), Config::object{{"k", 1}}); - //new_index->Search(GenQuery(10), Config::object{{"k", 1}}); - //new_index->Search(GenQuery(100), Config::object{{"k", 1}}); - //new_index->Search(GenQuery(1000), Config::object{{"k", 1}}); - //new_index->Search(GenQuery(10000), Config::object{{"k", 1}}); - - //std::cout << "k = 5" << std::endl; - //new_index->Search(GenQuery(1), Config::object{{"k", 5}}); - //new_index->Search(GenQuery(20), Config::object{{"k", 5}}); - //new_index->Search(GenQuery(100), Config::object{{"k", 5}}); - //new_index->Search(GenQuery(300), Config::object{{"k", 5}}); - //new_index->Search(GenQuery(500), Config::object{{"k", 5}}); - } -} - diff --git a/cpp/src/core/test/utils.cpp b/cpp/src/core/test/utils.cpp deleted file mode 100644 index 4e240f5f74d4288843513ded37c894b82b657e9d..0000000000000000000000000000000000000000 --- a/cpp/src/core/test/utils.cpp +++ /dev/null @@ -1,152 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - -#include "utils.h" - -INITIALIZE_EASYLOGGINGPP - -void InitLog() { - el::Configurations defaultConf; - defaultConf.setToDefault(); - defaultConf.set(el::Level::Debug, - el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)"); - el::Loggers::reconfigureLogger("default", defaultConf); -} - -void DataGen::Init_with_default() { - Generate(dim, nb, nq); -} - -void DataGen::Generate(const int &dim, const int &nb, const int &nq) { - this->nb = nb; - this->nq = nq; - this->dim = dim; - - GenAll(dim, nb, xb, ids, nq, xq); - assert(xb.size() == dim * nb); - assert(xq.size() == dim * nq); - - base_dataset = generate_dataset(nb, dim, xb.data(), ids.data()); - query_dataset = generate_query_dataset(nq, dim, xq.data()); - -} -zilliz::knowhere::DatasetPtr DataGen::GenQuery(const int &nq) { - xq.resize(nq * dim); - for (size_t i = 0; i < nq * dim; ++i) { - xq[i] = xb[i]; - } - return generate_query_dataset(nq, dim, xq.data()); -} - -void GenAll(const int64_t dim, - const int64_t &nb, - std::vector &xb, - std::vector &ids, - const int64_t &nq, - std::vector &xq) { - xb.resize(nb * dim); - xq.resize(nq * dim); - ids.resize(nb); - GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data()); -} - -void GenAll(const int64_t &dim, - const int64_t &nb, - float *xb, - int64_t *ids, - const int64_t &nq, - float *xq) { - GenBase(dim, nb, xb, ids); - for (size_t i = 0; i < nq * dim; ++i) { - xq[i] = xb[i]; - } -} - -void GenBase(const int64_t &dim, - const int64_t &nb, - float *xb, - int64_t *ids) { - for (auto i = 0; i < nb; ++i) { - for (auto j = 0; j < dim; ++j) { - //p_data[i * d + j] = float(base + i); - xb[i * dim + j] = drand48(); - } - xb[dim * i] += i / 1000.; - ids[i] = i; - } -} - -FileIOReader::FileIOReader(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::in | std::ios::binary); -} - -FileIOReader::~FileIOReader() { - fs.close(); -} - -size_t FileIOReader::operator()(void *ptr, size_t size) { - fs.read(reinterpret_cast(ptr), size); - return size; -} - -FileIOWriter::FileIOWriter(const std::string &fname) { - name = fname; - fs = std::fstream(name, std::ios::out | std::ios::binary); -} - -FileIOWriter::~FileIOWriter() { - fs.close(); -} - -size_t FileIOWriter::operator()(void *ptr, size_t size) { - fs.write(reinterpret_cast(ptr), size); - return size; -} - -using namespace zilliz::knowhere; - -DatasetPtr -generate_dataset(int64_t nb, int64_t dim, float *xb, long *ids) { - std::vector shape{nb, dim}; - auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); - std::vector tensors{tensor}; - std::vector tensor_fields{ConstructFloatField("data")}; - auto tensor_schema = std::make_shared(tensor_fields); - - auto id_array = ConstructInt64Array((uint8_t *) ids, nb * sizeof(int64_t)); - std::vector arrays{id_array}; - std::vector array_fields{ConstructInt64Field("id")}; - auto array_schema = std::make_shared(tensor_fields); - - auto dataset = std::make_shared(std::move(arrays), array_schema, - std::move(tensors), tensor_schema); - return dataset; -} - -DatasetPtr -generate_query_dataset(int64_t nb, int64_t dim, float *xb) { - std::vector shape{nb, dim}; - auto tensor = ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); - std::vector tensors{tensor}; - std::vector tensor_fields{ConstructFloatField("data")}; - auto tensor_schema = std::make_shared(tensor_fields); - - auto dataset = std::make_shared(std::move(tensors), tensor_schema); - return dataset; -} diff --git a/cpp/src/core/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp b/cpp/src/core/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp index 2eefcea68f3c4ace8bc9e3614236e71eca123186..a5d410ce5e9490eedc99c817cce9886e4b132222 100644 --- a/cpp/src/core/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp +++ b/cpp/src/core/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp @@ -179,13 +179,13 @@ FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& ErrorCode FileMetadataSet::SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) { - // TODO: serialize file to mem? + // TODO(lxj): serialize file to mem? return ErrorCode::Fail; } ErrorCode FileMetadataSet::LoadMetadataFromMemory(void *pGraphMemFile) { - // TODO: not support yet + // TODO(lxj): not support yet return ErrorCode::Fail; } diff --git a/cpp/src/core/test/CMakeLists.txt b/cpp/src/core/unittest/CMakeLists.txt similarity index 81% rename from cpp/src/core/test/CMakeLists.txt rename to cpp/src/core/unittest/CMakeLists.txt index 0cf83de24748de7f185d01b0ba9a54dc8ae8a140..0a52a2ed837577886211f374cf630ff4179be00d 100644 --- a/cpp/src/core/test/CMakeLists.txt +++ b/cpp/src/core/unittest/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${CORE_SOURCE_DIR}/thirdparty) include_directories(${CORE_SOURCE_DIR}/thirdparty/SPTAG/AnnService) include_directories(${CORE_SOURCE_DIR}/knowhere) +include_directories(${CORE_SOURCE_DIR}) include_directories(/usr/local/cuda/include) link_directories(/usr/local/cuda/lib64) link_directories(${CORE_SOURCE_DIR}/thirdparty/tbb) @@ -29,7 +30,7 @@ set(util_srcs ${CORE_SOURCE_DIR}/knowhere/knowhere/adapter/ArrowAdapter.cpp ${CORE_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp ${CORE_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp - utils.cpp + ${CORE_SOURCE_DIR}/unittest/utils.cpp ) # @@ -53,18 +54,10 @@ target_link_libraries(test_ivf ${depend_libs} ${unittest_libs} ${basic_libs}) # set(idmap_srcs - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVF.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVF.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVFPQ.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp - ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexGPUIVFSQ.cpp ) if(NOT TARGET test_idmap) - add_executable(test_idmap test_idmap.cpp ${idmap_srcs} ${util_srcs}) + add_executable(test_idmap test_idmap.cpp ${idmap_srcs} ${ivf_srcs} ${util_srcs}) endif() target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs}) @@ -87,5 +80,5 @@ install(TARGETS test_idmap DESTINATION unittest) install(TARGETS test_kdt DESTINATION unittest) #add_subdirectory(faiss_ori) -#add_subdirectory(test_nsg) +add_subdirectory(test_nsg) diff --git a/cpp/src/core/unittest/SPTAG.cpp b/cpp/src/core/unittest/SPTAG.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62198d5495810db382cd6fd2a2b9dfe597811331 --- /dev/null +++ b/cpp/src/core/unittest/SPTAG.cpp @@ -0,0 +1,50 @@ +//// Licensed to the Apache Software Foundation (ASF) under one +//// or more contributor license agreements. See the NOTICE file +//// distributed with this work for additional information +//// regarding copyright ownership. The ASF licenses this file +//// to you under the Apache License, Version 2.0 (the +//// "License"); you may not use this file except in compliance +//// with the License. You may obtain a copy of the License at +//// +//// http://www.apache.org/licenses/LICENSE-2.0 +//// +//// Unless required by applicable law or agreed to in writing, +//// software distributed under the License is distributed on an +//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +//// KIND, either express or implied. See the License for the +//// specific language governing permissions and limitations +//// under the License. +// +//#include +//#include +//#include +//#include +//#include +// +// int +// main(int argc, char* argv[]) { +// using namespace SPTAG; +// const int d = 128; +// const int n = 100; +// +// auto p_data = new float[n * d]; +// +// auto index = VectorIndex::CreateInstance(IndexAlgoType::KDT, VectorValueType::Float); +// +// std::random_device rd; +// std::mt19937 mt(rd()); +// std::uniform_real_distribution dist(1.0, 2.0); +// +// for (auto i = 0; i < n; i++) { +// for (auto j = 0; j < d; j++) { +// p_data[i * d + j] = dist(mt) - 1; +// } +// } +// std::cout << "generate random n * d finished."; +// ByteArray data((uint8_t*)p_data, n * d * sizeof(float), true); +// +// auto vectorset = std::make_shared(data, VectorValueType::Float, d, n); +// index->BuildIndex(vectorset, nullptr); +// +// std::cout << index->GetFeatureDim(); +//} diff --git a/cpp/src/core/test/faiss_ori/CMakeLists.txt b/cpp/src/core/unittest/faiss_ori/CMakeLists.txt similarity index 100% rename from cpp/src/core/test/faiss_ori/CMakeLists.txt rename to cpp/src/core/unittest/faiss_ori/CMakeLists.txt diff --git a/cpp/src/core/test/faiss_ori/gpuresource_test.cpp b/cpp/src/core/unittest/faiss_ori/gpuresource_test.cpp similarity index 60% rename from cpp/src/core/test/faiss_ori/gpuresource_test.cpp rename to cpp/src/core/unittest/faiss_ori/gpuresource_test.cpp index 0dcc2766cc27ea2b7ea055dc87bfe5a1fa84e4e7..90383b944cab7c71bab197ca21fc6b5825d91ca8 100644 --- a/cpp/src/core/test/faiss_ori/gpuresource_test.cpp +++ b/cpp/src/core/unittest/faiss_ori/gpuresource_test.cpp @@ -17,46 +17,45 @@ #include -#include -#include -#include #include +#include #include +#include +#include #include -#include #include #include - -using namespace std::chrono_literals; +#include class TestGpuRes { public: TestGpuRes() { res_ = new faiss::gpu::StandardGpuResources; } + ~TestGpuRes() { delete res_; delete index_; } - std::shared_ptr Do() { - int d = 128; // dimension - int nb = 100000; // database size - int nq = 100; // nb of queries + + std::shared_ptr + Do() { + int d = 128; // dimension + int nb = 100000; // database size + int nq = 100; // nb of queries int nlist = 1638; - float *xb = new float[d * nb]; - float *xq = new float[d * nq]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; for (int i = 0; i < nb; i++) { - for (int j = 0; j < d; j++) - xb[d * i + j] = drand48(); + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); xb[d * i] += i / 1000.; } for (int i = 0; i < nq; i++) { - for (int j = 0; j < d; j++) - xq[d * i + j] = drand48(); + for (int j = 0; j < d; j++) xq[d * i + j] = drand48(); xq[d * i] += i / 1000.; } @@ -68,9 +67,10 @@ class TestGpuRes { host_index.reset(faiss::gpu::index_gpu_to_cpu(index_)); return host_index; } + private: - faiss::gpu::GpuResources *res_ = nullptr; - faiss::Index *index_ = nullptr; + faiss::gpu::GpuResources* res_ = nullptr; + faiss::Index* index_ = nullptr; }; TEST(gpuresource, resource) { @@ -79,30 +79,28 @@ TEST(gpuresource, resource) { } TEST(test, resource_re) { - int d = 128; // dimension - int nb = 1000000; // database size - int nq = 100; // nb of queries + int d = 128; // dimension + int nb = 1000000; // database size + int nq = 100; // nb of queries int nlist = 16384; int k = 100; - float *xb = new float[d * nb]; - float *xq = new float[d * nq]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; for (int i = 0; i < nb; i++) { - for (int j = 0; j < d; j++) - xb[d * i + j] = drand48(); + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); xb[d * i] += i / 1000.; } for (int i = 0; i < nq; i++) { - for (int j = 0; j < d; j++) - xq[d * i + j] = drand48(); + for (int j = 0; j < d; j++) xq[d * i + j] = drand48(); xq[d * i] += i / 1000.; } auto elems = nq * k; - auto res_ids = (int64_t *) malloc(sizeof(int64_t) * elems); - auto res_dis = (float *) malloc(sizeof(float) * elems); + auto res_ids = (int64_t*)malloc(sizeof(int64_t) * elems); + auto res_dis = (float*)malloc(sizeof(float) * elems); faiss::gpu::StandardGpuResources res; auto cpu_index = faiss::index_factory(d, "IVF16384, Flat"); @@ -117,7 +115,7 @@ TEST(test, resource_re) { auto load = [&] { std::cout << "start" << std::endl; faiss::gpu::StandardGpuResources res; - //res.noTempMemory(); + // res.noTempMemory(); for (int l = 0; l < 100; ++l) { auto x = faiss::gpu::index_cpu_to_gpu(&res, 1, new_index); delete x; @@ -126,42 +124,42 @@ TEST(test, resource_re) { }; auto search = [&] { - faiss::gpu::StandardGpuResources res; - auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 1, new_index); - std::cout << "search start" << std::endl; - for (int l = 0; l < 10000; ++l) { - device_index->search(nq,xq,10, res_dis, res_ids); - } - std::cout << "search finish" << std::endl; - delete device_index; - delete cpu_index; + faiss::gpu::StandardGpuResources res; + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 1, new_index); + std::cout << "search start" << std::endl; + for (int l = 0; l < 10000; ++l) { + device_index->search(nq, xq, 10, res_dis, res_ids); + } + std::cout << "search finish" << std::endl; + delete device_index; + delete cpu_index; }; load(); search(); std::thread t1(search); - std::this_thread::sleep_for(1s); + std::this_thread::sleep_for(std::chrono::seconds(1)); std::thread t2(load); t1.join(); t2.join(); std::cout << "finish clone" << std::endl; - //std::this_thread::sleep_for(5s); + // std::this_thread::sleep_for(5s); // - //auto device_index_2 = faiss::gpu::index_cpu_to_gpu(&res, 1, cpu_index); - //device_index->train(nb, xb); - //device_index->add(nb, xb); + // auto device_index_2 = faiss::gpu::index_cpu_to_gpu(&res, 1, cpu_index); + // device_index->train(nb, xb); + // device_index->add(nb, xb); - //std::cout << "finish clone" << std::endl; - //std::this_thread::sleep_for(5s); + // std::cout << "finish clone" << std::endl; + // std::this_thread::sleep_for(5s); - //std::this_thread::sleep_for(2s); - //std::cout << "start clone" << std::endl; - //auto new_index = faiss::clone_index(device_index); - //std::cout << "start search" << std::endl; - //new_index->search(nq, xq, k, res_dis, res_ids); + // std::this_thread::sleep_for(2s); + // std::cout << "start clone" << std::endl; + // auto new_index = faiss::clone_index(device_index); + // std::cout << "start search" << std::endl; + // new_index->search(nq, xq, k, res_dis, res_ids); - //std::cout << "start clone" << std::endl; + // std::cout << "start clone" << std::endl; //{ // faiss::gpu::StandardGpuResources res; // auto cpu_index = faiss::index_factory(d, "IVF1638, Flat"); @@ -174,5 +172,5 @@ TEST(test, resource_re) { // std::cout << "finish clone" << std::endl; //} // - //std::cout << "finish clone" << std::endl; + // std::cout << "finish clone" << std::endl; } diff --git a/cpp/src/core/unittest/kdtree.cpp b/cpp/src/core/unittest/kdtree.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b1e428a8052311115dc1f0c44cbefaa5f7eccf6 --- /dev/null +++ b/cpp/src/core/unittest/kdtree.cpp @@ -0,0 +1,149 @@ +//// Licensed to the Apache Software Foundation (ASF) under one +//// or more contributor license agreements. See the NOTICE file +//// distributed with this work for additional information +//// regarding copyright ownership. The ASF licenses this file +//// to you under the Apache License, Version 2.0 (the +//// "License"); you may not use this file except in compliance +//// with the License. You may obtain a copy of the License at +//// +//// http://www.apache.org/licenses/LICENSE-2.0 +//// +//// Unless required by applicable law or agreed to in writing, +//// software distributed under the License is distributed on an +//// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +//// KIND, either express or implied. See the License for the +//// specific language governing permissions and limitations +//// under the License. +// +//#include +//#include +//#include "knowhere/adapter/sptag.h" +//#include "knowhere/adapter/structure.h" +//#include "knowhere/index/vector_index/cpu_kdt_rng.h" +//#include "knowhere/index/vector_index/definitions.h" +// +// namespace { +// +// namespace kn = knowhere; +// +//} // namespace +// +// kn::DatasetPtr +// generate_dataset(int64_t n, int64_t d, int64_t base) { +// auto elems = n * d; +// auto p_data = (float*)malloc(elems * sizeof(float)); +// auto p_id = (int64_t*)malloc(elems * sizeof(int64_t)); +// assert(p_data != nullptr && p_id != nullptr); +// +// for (auto i = 0; i < n; ++i) { +// for (auto j = 0; j < d; ++j) { +// p_data[i * d + j] = float(base + i); +// } +// p_id[i] = i; +// } +// +// std::vector shape{n, d}; +// auto tensor = ConstructFloatTensorSmart((uint8_t*)p_data, elems * sizeof(float), shape); +// std::vector tensors{tensor}; +// std::vector tensor_fields{ConstructFloatField("data")}; +// auto tensor_schema = std::make_shared(tensor_fields); +// +// auto id_array = ConstructInt64ArraySmart((uint8_t*)p_id, n * sizeof(int64_t)); +// std::vector arrays{id_array}; +// std::vector array_fields{ConstructInt64Field("id")}; +// auto array_schema = std::make_shared(tensor_fields); +// +// auto dataset = std::make_shared(std::move(arrays), array_schema, std::move(tensors), tensor_schema); +// +// return dataset; +//} +// +// kn::DatasetPtr +// generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) { +// size_t size = sizeof(float) * n * d; +// auto v = (float*)malloc(size); +// // TODO(lxj): check malloc +// for (auto i = 0; i < n; ++i) { +// for (auto j = 0; j < d; ++j) { +// v[i * d + j] = float(base + i); +// } +// } +// +// std::vector data; +// auto buffer = MakeMutableBufferSmart((uint8_t*)v, size); +// std::vector shape{n, d}; +// auto float_type = std::make_shared(); +// auto tensor = std::make_shared(float_type, buffer, shape); +// data.push_back(tensor); +// +// Config meta; +// meta[META_ROWS] = int64_t(n); +// meta[META_DIM] = int64_t(d); +// meta[META_K] = int64_t(k); +// +// auto type = std::make_shared(); +// auto field = std::make_shared("data", type); +// std::vector fields{field}; +// auto schema = std::make_shared(fields); +// +// return std::make_shared(data, schema); +//} +// +// int +// main(int argc, char* argv[]) { +// auto kdt_index = std::make_shared(); +// +// const auto d = 10; +// const auto k = 3; +// const auto nquery = 10; +// +// // ID [0, 99] +// auto train = generate_dataset(100, d, 0); +// // ID [100] +// auto base = generate_dataset(1, d, 0); +// auto queries = generate_queries(nquery, d, k, 0); +// +// // Build Preprocessor +// auto preprocessor = kdt_index->BuildPreprocessor(train, Config()); +// +// // Set Preprocessor +// kdt_index->set_preprocessor(preprocessor); +// +// Config train_config; +// train_config["TPTNumber"] = "64"; +// // Train +// kdt_index->Train(train, train_config); +// +// // Add +// kdt_index->Add(base, Config()); +// +// auto binary = kdt_index->Serialize(); +// auto new_index = std::make_shared(); +// new_index->Load(binary); +// // auto new_index = kdt_index; +// +// Config search_config; +// search_config[META_K] = int64_t(k); +// +// // Search +// auto result = new_index->Search(queries, search_config); +// +// // Print Result +// { +// auto ids = result->array()[0]; +// auto dists = result->array()[1]; +// +// std::stringstream ss_id; +// std::stringstream ss_dist; +// for (auto i = 0; i < nquery; i++) { +// for (auto j = 0; j < k; ++j) { +// ss_id << *ids->data()->GetValues(1, i * k + j) << " "; +// ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; +// } +// ss_id << std::endl; +// ss_dist << std::endl; +// } +// std::cout << "id\n" << ss_id.str() << std::endl; +// std::cout << "dist\n" << ss_dist.str() << std::endl; +// } +//} diff --git a/cpp/src/core/test/sift.50NN.graph b/cpp/src/core/unittest/sift.50NN.graph similarity index 100% rename from cpp/src/core/test/sift.50NN.graph rename to cpp/src/core/unittest/sift.50NN.graph diff --git a/cpp/src/core/test/siftsmall_base.fvecs b/cpp/src/core/unittest/siftsmall_base.fvecs similarity index 100% rename from cpp/src/core/test/siftsmall_base.fvecs rename to cpp/src/core/unittest/siftsmall_base.fvecs diff --git a/cpp/src/core/test/test_idmap.cpp b/cpp/src/core/unittest/test_idmap.cpp similarity index 72% rename from cpp/src/core/test/test_idmap.cpp rename to cpp/src/core/unittest/test_idmap.cpp index 3d71bca93177aedbfae52b974de21f3bd5faa303..e907d309eb36c4563aeb420d4ef9cace25d75d51 100644 --- a/cpp/src/core/test/test_idmap.cpp +++ b/cpp/src/core/unittest/test_idmap.cpp @@ -15,51 +15,51 @@ // specific language governing permissions and limitations // under the License. - #include - #include -#include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/adapter/Structure.h" -#include "knowhere/index/vector_index/helpers/Cloner.h" #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" + +#include "unittest/utils.h" -#include "utils.h" +namespace { +namespace kn = knowhere; -using namespace zilliz::knowhere; -using namespace zilliz::knowhere::cloner; +} // namespace static int device_id = 0; class IDMAPTest : public DataGen, public ::testing::Test { protected: - void SetUp() override { - FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*300, 2); + void + SetUp() override { + kn::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 300, 2); Init_with_default(); - index_ = std::make_shared(); + index_ = std::make_shared(); } - void TearDown() override { - FaissGpuResourceMgr::GetInstance().Free(); + void + TearDown() override { + kn::FaissGpuResourceMgr::GetInstance().Free(); } protected: - IDMAPPtr index_ = nullptr; + kn::IDMAPPtr index_ = nullptr; }; -void AssertAnns(const DatasetPtr &result, - const int &nq, - const int &k) { +void +AssertAnns(const kn::DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; for (auto i = 0; i < nq; i++) { EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); } } -void PrintResult(const DatasetPtr &result, - const int &nq, - const int &k) { +void +PrintResult(const kn::DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; auto dists = result->array()[1]; @@ -80,10 +80,10 @@ void PrintResult(const DatasetPtr &result, TEST_F(IDMAPTest, idmap_basic) { ASSERT_TRUE(!xb.empty()); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->d = dim; conf->k = k; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; index_->Train(conf); index_->Add(base_dataset, conf); @@ -97,7 +97,7 @@ TEST_F(IDMAPTest, idmap_basic) { index_->Seal(); auto binaryset = index_->Serialize(); - auto new_index = std::make_shared(); + auto new_index = std::make_shared(); new_index->Load(binaryset); auto re_result = index_->Search(query_dataset, conf); AssertAnns(re_result, nq, k); @@ -105,23 +105,23 @@ TEST_F(IDMAPTest, idmap_basic) { } TEST_F(IDMAPTest, idmap_serialize) { - auto serialize = [](const std::string &filename, BinaryPtr &bin, uint8_t *ret) { + auto serialize = [](const std::string& filename, kn::BinaryPtr& bin, uint8_t* ret) { FileIOWriter writer(filename); - writer(static_cast(bin->data.get()), bin->size); + writer(static_cast(bin->data.get()), bin->size); FileIOReader reader(filename); reader(ret, bin->size); }; - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->d = dim; conf->k = k; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; { // serialize index index_->Train(conf); - index_->Add(base_dataset, Config()); + index_->Add(base_dataset, kn::Config()); auto re_result = index_->Search(query_dataset, conf); AssertAnns(re_result, nq, k); PrintResult(re_result, nq, k); @@ -151,10 +151,10 @@ TEST_F(IDMAPTest, idmap_serialize) { TEST_F(IDMAPTest, copy_test) { ASSERT_TRUE(!xb.empty()); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->d = dim; conf->k = k; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; index_->Train(conf); index_->Add(base_dataset, conf); @@ -164,7 +164,7 @@ TEST_F(IDMAPTest, copy_test) { ASSERT_TRUE(index_->GetRawIds() != nullptr); auto result = index_->Search(query_dataset, conf); AssertAnns(result, nq, k); - //PrintResult(result, nq, k); + // PrintResult(result, nq, k); { // clone @@ -175,13 +175,13 @@ TEST_F(IDMAPTest, copy_test) { { // cpu to gpu - auto clone_index = CopyCpuToGpu(index_, device_id, conf); + auto clone_index = kn::cloner::CopyCpuToGpu(index_, device_id, conf); auto clone_result = clone_index->Search(query_dataset, conf); AssertAnns(clone_result, nq, k); - ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawVectors(); }, - zilliz::knowhere::KnowhereException); - ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawIds(); }, - zilliz::knowhere::KnowhereException); + ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawVectors(); }, + knowhere::KnowhereException); + ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawIds(); }, + knowhere::KnowhereException); auto binary = clone_index->Serialize(); clone_index->Load(binary); @@ -193,15 +193,15 @@ TEST_F(IDMAPTest, copy_test) { AssertAnns(clone_gpu_res, nq, k); // gpu to cpu - auto host_index = CopyGpuToCpu(clone_index, conf); + auto host_index = kn::cloner::CopyGpuToCpu(clone_index, conf); auto host_result = host_index->Search(query_dataset, conf); AssertAnns(host_result, nq, k); - ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawVectors() != nullptr); - ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawIds() != nullptr); + ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawVectors() != nullptr); + ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawIds() != nullptr); // gpu to gpu - auto device_index = CopyCpuToGpu(index_, device_id, conf); - auto new_device_index = std::static_pointer_cast(device_index)->CopyGpuToGpu(device_id, conf); + auto device_index = kn::cloner::CopyCpuToGpu(index_, device_id, conf); + auto new_device_index = std::static_pointer_cast(device_index)->CopyGpuToGpu(device_id, conf); auto device_result = new_device_index->Search(query_dataset, conf); AssertAnns(device_result, nq, k); } diff --git a/cpp/src/core/test/test_ivf.cpp b/cpp/src/core/unittest/test_ivf.cpp similarity index 68% rename from cpp/src/core/test/test_ivf.cpp rename to cpp/src/core/unittest/test_ivf.cpp index f1d195654811b7e30a8f2a4b0c465375f0ca6a04..38d3c76108802ba17414b18c4097355a973bfeae 100644 --- a/cpp/src/core/test/test_ivf.cpp +++ b/cpp/src/core/unittest/test_ivf.cpp @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #include #include @@ -25,44 +24,47 @@ #include #include -#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" -#include "knowhere/adapter/Structure.h" -#include "knowhere/index/vector_index/helpers/Cloner.h" -#include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexGPUIVF.h" -#include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexGPUIVFPQ.h" -#include "knowhere/index/vector_index/IndexIVFSQ.h" #include "knowhere/index/vector_index/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" + +#include "unittest/utils.h" -#include "utils.h" +namespace { -using namespace zilliz::knowhere; -using namespace zilliz::knowhere::cloner; +namespace kn = knowhere; +} // namespace + +using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -using ::testing::Combine; constexpr int device_id = 0; constexpr int64_t DIM = 128; -constexpr int64_t NB = 1000000/100; +constexpr int64_t NB = 1000000 / 100; constexpr int64_t NQ = 10; constexpr int64_t K = 10; -IVFIndexPtr IndexFactory(const std::string &type) { +kn::IVFIndexPtr +IndexFactory(const std::string& type) { if (type == "IVF") { - return std::make_shared(); + return std::make_shared(); } else if (type == "IVFPQ") { - return std::make_shared(); + return std::make_shared(); } else if (type == "GPUIVF") { - return std::make_shared(device_id); + return std::make_shared(device_id); } else if (type == "GPUIVFPQ") { - return std::make_shared(device_id); + return std::make_shared(device_id); } else if (type == "IVFSQ") { - return std::make_shared(); + return std::make_shared(); } else if (type == "GPUIVFSQ") { return std::make_shared(device_id); } else if (type == "IVFSQHybrid") { @@ -80,24 +82,25 @@ enum class ParameterType { class ParamGenerator { public: - static ParamGenerator& GetInstance(){ + static ParamGenerator& + GetInstance() { static ParamGenerator instance; return instance; } - Config Gen(const ParameterType& type){ + kn::Config + Gen(const ParameterType& type) { if (type == ParameterType::ivf) { - auto tempconf = std::make_shared(); + auto tempconf = std::make_shared(); tempconf->d = DIM; tempconf->gpu_id = device_id; tempconf->nlist = 100; tempconf->nprobe = 16; tempconf->k = K; - tempconf->metric_type = METRICTYPE::L2; + tempconf->metric_type = kn::METRICTYPE::L2; return tempconf; - } - else if (type == ParameterType::ivfpq) { - auto tempconf = std::make_shared(); + } else if (type == ParameterType::ivfpq) { + auto tempconf = std::make_shared(); tempconf->d = DIM; tempconf->gpu_id = device_id; tempconf->nlist = 100; @@ -105,10 +108,9 @@ class ParamGenerator { tempconf->k = K; tempconf->m = 8; tempconf->nbits = 8; - tempconf->metric_type = METRICTYPE::L2; + tempconf->metric_type = kn::METRICTYPE::L2; return tempconf; - } - else if (type == ParameterType::ivfsq || type == ParameterType::ivfsqhybrid) { + } else if (type == ParameterType::ivfsq || type == ParameterType::ivfsqhybrid) { auto tempconf = std::make_shared(); tempconf->d = DIM; tempconf->gpu_id = device_id; @@ -116,29 +118,32 @@ class ParamGenerator { tempconf->nprobe = 16; tempconf->k = K; tempconf->nbits = 8; - tempconf->metric_type = METRICTYPE::L2; + tempconf->metric_type = kn::METRICTYPE::L2; return tempconf; } } }; -class IVFTest - : public DataGen, public TestWithParam<::std::tuple> { +class IVFTest : public DataGen, public TestWithParam<::std::tuple> { protected: - void SetUp() override { + void + SetUp() override { ParameterType parameter_type; std::tie(index_type, parameter_type) = GetParam(); - //Init_with_default(); + // Init_with_default(); Generate(DIM, NB, NQ); index_ = IndexFactory(index_type); conf = ParamGenerator::GetInstance().Gen(parameter_type); - FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*600, 2); + kn::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 600, 2); } - void TearDown() override { - FaissGpuResourceMgr::GetInstance().Free(); + + void + TearDown() override { + kn::FaissGpuResourceMgr::GetInstance().Free(); } - VectorIndexPtr ChooseTodo() { + VectorIndexPtr + ChooseTodo() { std::vector gpu_idx{"GPUIVFSQ"}; auto finder = std::find(gpu_idx.cbegin(), gpu_idx.cend(), index_type); if (finder != gpu_idx.cend()) { @@ -149,36 +154,29 @@ class IVFTest protected: std::string index_type; - Config conf; - IVFIndexPtr index_ = nullptr; + kn::Config conf; + kn::IVFIndexPtr index_ = nullptr; }; - - INSTANTIATE_TEST_CASE_P(IVFParameters, IVFTest, - Values( - std::make_tuple("IVF", ParameterType::ivf), - std::make_tuple("GPUIVF", ParameterType::ivf), -// std::make_tuple("IVFPQ", ParameterType::ivfpq), -// std::make_tuple("GPUIVFPQ", ParameterType::ivfpq), - std::make_tuple("IVFSQ", ParameterType::ivfsq), - std::make_tuple("GPUIVFSQ", ParameterType::ivfsq) - std::make_tuple("IVFSQHybrid", ParameterType::ivfsqhybrid) - ) -); - -void AssertAnns(const DatasetPtr &result, - const int &nq, - const int &k) { + Values(std::make_tuple("IVF", ParameterType::ivf), + std::make_tuple("GPUIVF", ParameterType::ivf), + // std::make_tuple("IVFPQ", ParameterType::ivfpq), + // std::make_tuple("GPUIVFPQ", ParameterType::ivfpq), + std::make_tuple("IVFSQ", ParameterType::ivfsq), + std::make_tuple("GPUIVFSQ", ParameterType::ivfsq) + std::make_tuple("IVFSQHybrid", ParameterType::ivfsqhybrid))); + +void +AssertAnns(const DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; for (auto i = 0; i < nq; i++) { EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); } } -void PrintResult(const DatasetPtr &result, - const int &nq, - const int &k) { +void +PrintResult(const kn::DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; auto dists = result->array()[1]; @@ -211,7 +209,7 @@ TEST_P(IVFTest, ivf_basic) { auto new_idx = ChooseTodo(); auto result = new_idx->Search(query_dataset, conf); AssertAnns(result, nq, conf->k); - //PrintResult(result, nq, k); + // PrintResult(result, nq, k); } TEST_P(IVFTest, hybrid) { @@ -229,9 +227,9 @@ TEST_P(IVFTest, hybrid) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dimension(), dim); -// auto new_idx = ChooseTodo(); -// auto result = new_idx->Search(query_dataset, conf); -// AssertAnns(result, nq, conf->k); + // auto new_idx = ChooseTodo(); + // auto result = new_idx->Search(query_dataset, conf); + // AssertAnns(result, nq, conf->k); { auto hybrid_1_idx = std::make_shared(device_id); @@ -266,10 +264,9 @@ TEST_P(IVFTest, hybrid) { AssertAnns(result, nq, conf->k); PrintResult(result, nq, k); } - } -//TEST_P(IVFTest, gpu_to_cpu) { +// TEST_P(IVFTest, gpu_to_cpu) { // if (index_type.find("GPU") == std::string::npos) { return; } // // // else @@ -294,9 +291,9 @@ TEST_P(IVFTest, hybrid) { //} TEST_P(IVFTest, ivf_serialize) { - auto serialize = [](const std::string &filename, BinaryPtr &bin, uint8_t *ret) { + auto serialize = [](const std::string& filename, kn::BinaryPtr& bin, uint8_t* ret) { FileIOWriter writer(filename); - writer(static_cast(bin->data.get()), bin->size); + writer(static_cast(bin->data.get()), bin->size); FileIOReader reader(filename); reader(ret, bin->size); @@ -366,37 +363,36 @@ TEST_P(IVFTest, clone_test) { auto new_idx = ChooseTodo(); auto result = new_idx->Search(query_dataset, conf); AssertAnns(result, nq, conf->k); - //PrintResult(result, nq, k); + // PrintResult(result, nq, k); - auto AssertEqual = [&] (DatasetPtr p1, DatasetPtr p2) { + auto AssertEqual = [&](kn::DatasetPtr p1, kn::DatasetPtr p2) { auto ids_p1 = p1->array()[0]; auto ids_p2 = p2->array()[0]; for (int i = 0; i < nq * k; ++i) { - EXPECT_EQ(*(ids_p2->data()->GetValues(1, i)), - *(ids_p1->data()->GetValues(1, i))); + EXPECT_EQ(*(ids_p2->data()->GetValues(1, i)), *(ids_p1->data()->GetValues(1, i))); } }; -// { -// // clone in place -// std::vector support_idx_vec{"IVF", "GPUIVF", "IVFPQ", "IVFSQ", "GPUIVFSQ"}; -// auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); -// if (finder != support_idx_vec.cend()) { -// EXPECT_NO_THROW({ -// auto clone_index = index_->Clone(); -// auto clone_result = clone_index->Search(query_dataset, conf); -// //AssertAnns(result, nq, conf->k); -// AssertEqual(result, clone_result); -// std::cout << "inplace clone [" << index_type << "] success" << std::endl; -// }); -// } else { -// EXPECT_THROW({ -// std::cout << "inplace clone [" << index_type << "] failed" << std::endl; -// auto clone_index = index_->Clone(); -// }, KnowhereException); -// } -// } + // { + // // clone in place + // std::vector support_idx_vec{"IVF", "GPUIVF", "IVFPQ", "IVFSQ", "GPUIVFSQ"}; + // auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); + // if (finder != support_idx_vec.cend()) { + // EXPECT_NO_THROW({ + // auto clone_index = index_->Clone(); + // auto clone_result = clone_index->Search(query_dataset, conf); + // //AssertAnns(result, nq, conf->k); + // AssertEqual(result, clone_result); + // std::cout << "inplace clone [" << index_type << "] success" << std::endl; + // }); + // } else { + // EXPECT_THROW({ + // std::cout << "inplace clone [" << index_type << "] failed" << std::endl; + // auto clone_index = index_->Clone(); + // }, KnowhereException); + // } + // } { if (index_type == "IVFSQHybrid") { @@ -410,16 +406,18 @@ TEST_P(IVFTest, clone_test) { auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); if (finder != support_idx_vec.cend()) { EXPECT_NO_THROW({ - auto clone_index = CopyGpuToCpu(index_, Config()); - auto clone_result = clone_index->Search(query_dataset, conf); - AssertEqual(result, clone_result); - std::cout << "clone G <=> C [" << index_type << "] success" << std::endl; - }); + auto clone_index = kn::cloner::CopyGpuToCpu(index_, kn::Config()); + auto clone_result = clone_index->Search(query_dataset, conf); + AssertEqual(result, clone_result); + std::cout << "clone G <=> C [" << index_type << "] success" << std::endl; + }); } else { - EXPECT_THROW({ - std::cout << "clone G <=> C [" << index_type << "] failed" << std::endl; - auto clone_index = CopyGpuToCpu(index_, Config()); - }, KnowhereException); + EXPECT_THROW( + { + std::cout << "clone G <=> C [" << index_type << "] failed" << std::endl; + auto clone_index = kn::cloner::CopyGpuToCpu(index_, kn::Config()); + }, + kn::KnowhereException); } } @@ -429,22 +427,24 @@ TEST_P(IVFTest, clone_test) { auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); if (finder != support_idx_vec.cend()) { EXPECT_NO_THROW({ - auto clone_index = CopyCpuToGpu(index_, device_id, Config()); - auto clone_result = clone_index->Search(query_dataset, conf); - AssertEqual(result, clone_result); - std::cout << "clone C <=> G [" << index_type << "] success" << std::endl; - }); + auto clone_index = kn::cloner::CopyCpuToGpu(index_, device_id, kn::Config()); + auto clone_result = clone_index->Search(query_dataset, conf); + AssertEqual(result, clone_result); + std::cout << "clone C <=> G [" << index_type << "] success" << std::endl; + }); } else { - EXPECT_THROW({ - std::cout << "clone C <=> G [" << index_type << "] failed" << std::endl; - auto clone_index = CopyCpuToGpu(index_, device_id, Config()); - }, KnowhereException); + EXPECT_THROW( + { + std::cout << "clone C <=> G [" << index_type << "] failed" << std::endl; + auto clone_index = kn::cloner::CopyCpuToGpu(index_, device_id, kn::Config()); + }, + kn::KnowhereException); } } } TEST_P(IVFTest, seal_test) { - //FaissGpuResourceMgr::GetInstance().InitDevice(device_id); + // FaissGpuResourceMgr::GetInstance().InitDevice(device_id); std::vector support_idx_vec{"GPUIVF", "GPUIVFSQ", "IVFSQHybrid"}; auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); @@ -466,44 +466,44 @@ TEST_P(IVFTest, seal_test) { auto result = new_idx->Search(query_dataset, conf); AssertAnns(result, nq, conf->k); - auto cpu_idx = CopyGpuToCpu(index_, Config()); + auto cpu_idx = kn::cloner::CopyGpuToCpu(index_, kn::Config()); - TimeRecorder tc("CopyToGpu"); - CopyCpuToGpu(cpu_idx, device_id, Config()); + kn::TimeRecorder tc("CopyToGpu"); + kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); auto without_seal = tc.RecordSection("Without seal"); cpu_idx->Seal(); tc.RecordSection("seal cost"); - CopyCpuToGpu(cpu_idx, device_id, Config()); + kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); auto with_seal = tc.RecordSection("With seal"); ASSERT_GE(without_seal, with_seal); } - -class GPURESTEST - : public DataGen, public ::testing::Test { +class GPURESTEST : public DataGen, public ::testing::Test { protected: - void SetUp() override { + void + SetUp() override { Generate(128, 1000000, 1000); - FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*300, 2); + kn::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 300, 2); k = 100; elems = nq * k; - ids = (int64_t *) malloc(sizeof(int64_t) * elems); - dis = (float *) malloc(sizeof(float) * elems); + ids = (int64_t*)malloc(sizeof(int64_t) * elems); + dis = (float*)malloc(sizeof(float) * elems); } - void TearDown() override { + void + TearDown() override { delete ids; delete dis; - FaissGpuResourceMgr::GetInstance().Free(); + kn::FaissGpuResourceMgr::GetInstance().Free(); } protected: std::string index_type; - IVFIndexPtr index_ = nullptr; + kn::IVFIndexPtr index_ = nullptr; - int64_t *ids = nullptr; - float *dis = nullptr; + int64_t* ids = nullptr; + float* dis = nullptr; int64_t elems = 0; }; @@ -514,16 +514,16 @@ TEST_F(GPURESTEST, gpu_ivf_resource_test) { assert(!xb.empty()); { - index_ = std::make_shared(-1); - ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), -1); - std::dynamic_pointer_cast(index_)->SetGpuDevice(device_id); - ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), device_id); + index_ = std::make_shared(-1); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), -1); + std::dynamic_pointer_cast(index_)->SetGpuDevice(device_id); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), device_id); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->nlist = 1638; conf->d = dim; conf->gpu_id = device_id; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; conf->k = k; conf->nprobe = 1; @@ -535,7 +535,7 @@ TEST_F(GPURESTEST, gpu_ivf_resource_test) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dimension(), dim); - TimeRecorder tc("knowere GPUIVF"); + kn::TimeRecorder tc("knowere GPUIVF"); for (int i = 0; i < search_count; ++i) { index_->Search(query_dataset, conf); if (i > search_count - 6 || i < 5) @@ -543,7 +543,7 @@ TEST_F(GPURESTEST, gpu_ivf_resource_test) { } tc.ElapseFromBegin("search all"); } - FaissGpuResourceMgr::GetInstance().Dump(); + kn::FaissGpuResourceMgr::GetInstance().Dump(); { // IVF-Search @@ -554,7 +554,7 @@ TEST_F(GPURESTEST, gpu_ivf_resource_test) { device_index.train(nb, xb.data()); device_index.add(nb, xb.data()); - TimeRecorder tc("ori IVF"); + kn::TimeRecorder tc("ori IVF"); for (int i = 0; i < search_count; ++i) { device_index.search(nq, xq.data(), k, dis, ids); if (i > search_count - 6 || i < 5) @@ -562,7 +562,6 @@ TEST_F(GPURESTEST, gpu_ivf_resource_test) { } tc.ElapseFromBegin("search all"); } - } TEST_F(GPURESTEST, gpuivfsq) { @@ -571,11 +570,11 @@ TEST_F(GPURESTEST, gpuivfsq) { index_type = "GPUIVFSQ"; index_ = IndexFactory(index_type); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->nlist = 1638; conf->d = dim; conf->gpu_id = device_id; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; conf->k = k; conf->nbits = 8; conf->nprobe = 1; @@ -585,14 +584,14 @@ TEST_F(GPURESTEST, gpuivfsq) { auto model = index_->Train(base_dataset, conf); index_->set_index_model(model); index_->Add(base_dataset, conf); -// auto result = index_->Search(query_dataset, conf); -// AssertAnns(result, nq, k); + // auto result = index_->Search(query_dataset, conf); + // AssertAnns(result, nq, k); - auto cpu_idx = CopyGpuToCpu(index_, Config()); + auto cpu_idx = kn::cloner::CopyGpuToCpu(index_, kn::Config()); cpu_idx->Seal(); - TimeRecorder tc("knowhere GPUSQ8"); - auto search_idx = CopyCpuToGpu(cpu_idx, device_id, Config()); + kn::TimeRecorder tc("knowhere GPUSQ8"); + auto search_idx = kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); tc.RecordSection("Copy to gpu"); for (int i = 0; i < search_count; ++i) { search_idx->Search(query_dataset, conf); @@ -604,8 +603,8 @@ TEST_F(GPURESTEST, gpuivfsq) { { // Ori gpuivfsq Test - const char *index_description = "IVF1638,SQ8"; - faiss::Index *ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2); + const char* index_description = "IVF1638,SQ8"; + faiss::Index* ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2); faiss::gpu::StandardGpuResources res; auto device_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, ori_index); @@ -613,7 +612,7 @@ TEST_F(GPURESTEST, gpuivfsq) { device_index->add(nb, xb.data()); auto cpu_index = faiss::gpu::index_gpu_to_cpu(device_index); - auto idx = dynamic_cast(cpu_index); + auto idx = dynamic_cast(cpu_index); if (idx != nullptr) { idx->to_readonly(); } @@ -623,8 +622,8 @@ TEST_F(GPURESTEST, gpuivfsq) { faiss::gpu::GpuClonerOptions option; option.allInGpu = true; - TimeRecorder tc("ori GPUSQ8"); - faiss::Index *search_idx = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, &option); + kn::TimeRecorder tc("ori GPUSQ8"); + faiss::Index* search_idx = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, &option); tc.RecordSection("Copy to gpu"); for (int i = 0; i < search_count; ++i) { search_idx->search(nq, xq.data(), k, dis, ids); @@ -635,7 +634,6 @@ TEST_F(GPURESTEST, gpuivfsq) { delete cpu_index; delete search_idx; } - } TEST_F(GPURESTEST, copyandsearch) { @@ -645,11 +643,11 @@ TEST_F(GPURESTEST, copyandsearch) { index_type = "GPUIVFSQ"; index_ = IndexFactory(index_type); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->nlist = 1638; conf->d = dim; conf->gpu_id = device_id; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; conf->k = k; conf->nbits = 8; conf->nprobe = 1; @@ -659,35 +657,35 @@ TEST_F(GPURESTEST, copyandsearch) { auto model = index_->Train(base_dataset, conf); index_->set_index_model(model); index_->Add(base_dataset, conf); -// auto result = index_->Search(query_dataset, conf); -// AssertAnns(result, nq, k); + // auto result = index_->Search(query_dataset, conf); + // AssertAnns(result, nq, k); - auto cpu_idx = CopyGpuToCpu(index_, Config()); + auto cpu_idx = kn::cloner::CopyGpuToCpu(index_, kn::Config()); cpu_idx->Seal(); - auto search_idx = CopyCpuToGpu(cpu_idx, device_id, Config()); + auto search_idx = kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); auto search_func = [&] { - //TimeRecorder tc("search&load"); + // TimeRecorder tc("search&load"); for (int i = 0; i < search_count; ++i) { search_idx->Search(query_dataset, conf); - //if (i > search_count - 6 || i == 0) + // if (i > search_count - 6 || i == 0) // tc.RecordSection("search once"); } - //tc.ElapseFromBegin("search finish"); + // tc.ElapseFromBegin("search finish"); }; auto load_func = [&] { - //TimeRecorder tc("search&load"); + // TimeRecorder tc("search&load"); for (int i = 0; i < load_count; ++i) { - CopyCpuToGpu(cpu_idx, device_id, Config()); - //if (i > load_count -5 || i < 5) - //tc.RecordSection("Copy to gpu"); + kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); + // if (i > load_count -5 || i < 5) + // tc.RecordSection("Copy to gpu"); } - //tc.ElapseFromBegin("load finish"); + // tc.ElapseFromBegin("load finish"); }; - TimeRecorder tc("basic"); - CopyCpuToGpu(cpu_idx, device_id, Config()); + kn::TimeRecorder tc("basic"); + kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); tc.RecordSection("Copy to gpu once"); search_idx->Search(query_dataset, conf); tc.RecordSection("search once"); @@ -707,11 +705,11 @@ TEST_F(GPURESTEST, TrainAndSearch) { index_type = "GPUIVFSQ"; index_ = IndexFactory(index_type); - auto conf = std::make_shared(); + auto conf = std::make_shared(); conf->nlist = 1638; conf->d = dim; conf->gpu_id = device_id; - conf->metric_type = METRICTYPE::L2; + conf->metric_type = kn::METRICTYPE::L2; conf->k = k; conf->nbits = 8; conf->nprobe = 1; @@ -722,9 +720,9 @@ TEST_F(GPURESTEST, TrainAndSearch) { auto new_index = IndexFactory(index_type); new_index->set_index_model(model); new_index->Add(base_dataset, conf); - auto cpu_idx = CopyGpuToCpu(new_index, Config()); + auto cpu_idx = kn::cloner::CopyGpuToCpu(new_index, kn::Config()); cpu_idx->Seal(); - auto search_idx = CopyCpuToGpu(cpu_idx, device_id, Config()); + auto search_idx = kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); constexpr int train_count = 1; constexpr int search_count = 5000; @@ -736,18 +734,18 @@ TEST_F(GPURESTEST, TrainAndSearch) { test_idx->Add(base_dataset, conf); } }; - auto search_stage = [&](VectorIndexPtr& search_idx) { + auto search_stage = [&](kn::VectorIndexPtr& search_idx) { for (int i = 0; i < search_count; ++i) { auto result = search_idx->Search(query_dataset, conf); AssertAnns(result, nq, k); } }; - //TimeRecorder tc("record"); - //train_stage(); - //tc.RecordSection("train cost"); - //search_stage(search_idx); - //tc.RecordSection("search cost"); + // TimeRecorder tc("record"); + // train_stage(); + // tc.RecordSection("train cost"); + // search_stage(search_idx); + // tc.RecordSection("search cost"); { // search and build parallel @@ -765,7 +763,7 @@ TEST_F(GPURESTEST, TrainAndSearch) { } { // search parallel - auto search_idx_2 = CopyCpuToGpu(cpu_idx, device_id, Config()); + auto search_idx_2 = kn::cloner::CopyCpuToGpu(cpu_idx, device_id, kn::Config()); std::thread search_1(search_stage, std::ref(search_idx)); std::thread search_2(search_stage, std::ref(search_idx_2)); search_1.join(); @@ -773,6 +771,4 @@ TEST_F(GPURESTEST, TrainAndSearch) { } } - - -// TODO(linxj): Add exception test +// TODO(lxj): Add exception test diff --git a/cpp/src/core/test/test_json.cpp b/cpp/src/core/unittest/test_json.cpp similarity index 91% rename from cpp/src/core/test/test_json.cpp rename to cpp/src/core/unittest/test_json.cpp index b3c4fd4993c5a8aa87276ffaaa711fc56f704b76..613e0deff5dfdcfe847f256233822a113af993c1 100644 --- a/cpp/src/core/test/test_json.cpp +++ b/cpp/src/core/unittest/test_json.cpp @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. - #include "knowhere/common/config.h" -using namespace zilliz::knowhere; +namespace { + +namespace kn = knowhere; + +} // namespace -int main(){ - Config cfg; +int +main() { + kn::Config cfg; cfg["size"] = size_t(199); auto size = cfg.get_with_default("size", 123); diff --git a/cpp/src/core/test/test_kdt.cpp b/cpp/src/core/unittest/test_kdt.cpp similarity index 80% rename from cpp/src/core/test/test_kdt.cpp rename to cpp/src/core/unittest/test_kdt.cpp index fce233aca7f0272422e9f69dd4e4b90dd50a5c95..f9e02bd9a4cebcebf86934633fc80e36b5f4b40d 100644 --- a/cpp/src/core/test/test_kdt.cpp +++ b/cpp/src/core/unittest/test_kdt.cpp @@ -15,35 +15,36 @@ // specific language governing permissions and limitations // under the License. - #include #include #include +#include "knowhere/adapter/SptagAdapter.h" +#include "knowhere/adapter/Structure.h" #include "knowhere/common/Exception.h" #include "knowhere/index/vector_index/IndexKDT.h" #include "knowhere/index/vector_index/helpers/Definitions.h" -#include "knowhere/adapter/SptagAdapter.h" -#include "knowhere/adapter/Structure.h" -#include "utils.h" +#include "unittest/utils.h" +namespace { -using namespace zilliz::knowhere; +namespace kn = knowhere; +} // namespace + +using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -using ::testing::Combine; - -class KDTTest - : public DataGen, public ::testing::Test { +class KDTTest : public DataGen, public ::testing::Test { protected: - void SetUp() override { - index_ = std::make_shared(); + void + SetUp() override { + index_ = std::make_shared(); - auto tempconf = std::make_shared(); + auto tempconf = std::make_shared(); tempconf->tptnubmber = 1; tempconf->k = 10; conf = tempconf; @@ -52,22 +53,20 @@ class KDTTest } protected: - Config conf; - std::shared_ptr index_ = nullptr; + kn::Config conf; + std::shared_ptr index_ = nullptr; }; -void AssertAnns(const DatasetPtr &result, - const int &nq, - const int &k) { +void +AssertAnns(const kn::DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; for (auto i = 0; i < nq; i++) { EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); } } -void PrintResult(const DatasetPtr &result, - const int &nq, - const int &k) { +void +PrintResult(const kn::DatasetPtr& result, const int& nq, const int& k) { auto ids = result->array()[0]; auto dists = result->array()[1]; @@ -85,7 +84,7 @@ void PrintResult(const DatasetPtr &result, std::cout << "dist\n" << ss_dist.str() << std::endl; } -// TODO(linxj): add test about count() and dimension() +// TODO(lxj): add test about count() and dimension() TEST_F(KDTTest, kdt_basic) { assert(!xb.empty()); @@ -124,25 +123,25 @@ TEST_F(KDTTest, kdt_serialize) { index_->set_preprocessor(preprocessor); auto model = index_->Train(base_dataset, conf); - //index_->Add(base_dataset, conf); + // index_->Add(base_dataset, conf); auto binaryset = index_->Serialize(); - auto new_index = std::make_shared(); + auto new_index = std::make_shared(); new_index->Load(binaryset); auto result = new_index->Search(query_dataset, conf); AssertAnns(result, nq, k); PrintResult(result, nq, k); ASSERT_EQ(new_index->Count(), nb); ASSERT_EQ(new_index->Dimension(), dim); - ASSERT_THROW({new_index->Clone();}, zilliz::knowhere::KnowhereException); - ASSERT_NO_THROW({new_index->Seal();}); + ASSERT_THROW({ new_index->Clone(); }, knowhere::KnowhereException); + ASSERT_NO_THROW({ new_index->Seal(); }); { int fileno = 0; - const std::string &base_name = "/tmp/kdt_serialize_test_bin_"; + const std::string& base_name = "/tmp/kdt_serialize_test_bin_"; std::vector filename_list; - std::vector> meta_list; - for (auto &iter: binaryset.binary_map_) { - const std::string &filename = base_name + std::to_string(fileno); + std::vector> meta_list; + for (auto& iter : binaryset.binary_map_) { + const std::string& filename = base_name + std::to_string(fileno); FileIOWriter writer(filename); writer(iter.second->data.get(), iter.second->size); @@ -151,7 +150,7 @@ TEST_F(KDTTest, kdt_serialize) { ++fileno; } - BinarySet load_data_list; + kn::BinarySet load_data_list; for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) { auto bin_size = meta_list[i].second; FileIOReader reader(filename_list[i]); @@ -163,7 +162,7 @@ TEST_F(KDTTest, kdt_serialize) { load_data_list.Append(meta_list[i].first, data, bin_size); } - auto new_index = std::make_shared(); + auto new_index = std::make_shared(); new_index->Load(load_data_list); auto result = new_index->Search(query_dataset, conf); AssertAnns(result, nq, k); diff --git a/cpp/src/core/test/test_nsg.cpp b/cpp/src/core/unittest/test_nsg.cpp similarity index 73% rename from cpp/src/core/test/test_nsg.cpp rename to cpp/src/core/unittest/test_nsg.cpp index f15692502781f0fd026ef8717ef4b999741daaa2..a4b6006e719a688a1d264101f92c4e7fd0dccb8d 100644 --- a/cpp/src/core/test/test_nsg.cpp +++ b/cpp/src/core/unittest/test_nsg.cpp @@ -15,46 +15,46 @@ // specific language governing permissions and limitations // under the License. - #include #include -#include -#include "index.h" +#include "knowhere/index/index.h" +#include "test/utils.h" //#include -using namespace zilliz::knowhere; - -void load_data(std::string &filename, float *&data, unsigned &num, - unsigned &dim) { // load data with sift10K pattern +void +load_data(std::string& filename, float*& data, unsigned& num, + unsigned& dim) { // load data with sift10K pattern std::ifstream in(filename, std::ios::binary); if (!in.is_open()) { std::cout << "open file error" << std::endl; exit(-1); } - in.read((char *) &dim, 4); + in.read((char*)&dim, 4); in.seekg(0, std::ios::end); std::ios::pos_type ss = in.tellg(); - size_t fsize = (size_t) ss; - num = (unsigned) (fsize / (dim + 1) / 4); - data = new float[(size_t) num * (size_t) dim]; + size_t fsize = (size_t)ss; + num = (unsigned)(fsize / (dim + 1) / 4); + data = new float[(size_t)num * (size_t)dim]; in.seekg(0, std::ios::beg); for (size_t i = 0; i < num; i++) { in.seekg(4, std::ios::cur); - in.read((char *) (data + i * dim), dim * 4); + in.read((char*)(data + i * dim), dim * 4); } in.close(); } -void test_distance() { +void +test_distance() { std::vector xb{1, 2, 3, 4}; std::vector xq{2, 2, 3, 4}; float r = calculate(xb.data(), xq.data(), 4); std::cout << r << std::endl; } -int main() { +int +main() { test_distance(); BuildParams params; @@ -62,16 +62,16 @@ int main() { params.candidate_pool_size = 100; params.out_degree = 50; - float *data = nullptr; - long *ids = nullptr; + float* data = nullptr; + int64_t* ids = nullptr; unsigned ntotal, dim; std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/siftsmall/siftsmall_base.fvecs"; - //std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift/sift_base.fvecs"; + // std::string filename = "/home/zilliz/opt/workspace/wook/efanna_graph/tests/sift/sift_base.fvecs"; load_data(filename, data, ntotal, dim); assert(data); - //float x = calculate(data + dim * 0, data + dim * 62, dim); - //std::cout << x << std::endl; + // float x = calculate(data + dim * 0, data + dim * 62, dim); + // std::cout << x << std::endl; NsgIndex index(dim, ntotal); @@ -81,24 +81,23 @@ int main() { std::chrono::duration diff = e - s; std::cout << "indexing time: " << diff.count() << "\n"; - int k = 10; int nq = 1000; SearchParams s_params; s_params.search_length = 50; - auto dist = new float[nq*k]; - auto ids_b = new long[nq*k]; + auto dist = new float[nq * k]; + auto ids_b = new int64_t[nq * k]; s = std::chrono::high_resolution_clock::now(); - //ProfilerStart("xx.prof"); + // ProfilerStart("xx.prof"); index.Search(data, nq, dim, k, dist, ids_b, s_params); - //ProfilerStop(); + // ProfilerStop(); e = std::chrono::high_resolution_clock::now(); diff = e - s; std::cout << "search time: " << diff.count() << "\n"; for (int i = 0; i < k; ++i) { std::cout << "id " << ids_b[i] << std::endl; - //std::cout << "dist " << dist[i] << std::endl; + // std::cout << "dist " << dist[i] << std::endl; } delete[] dist; @@ -106,5 +105,3 @@ int main() { return 0; } - - diff --git a/cpp/src/core/unittest/test_nsg/CMakeLists.txt b/cpp/src/core/unittest/test_nsg/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..393cc188b09a370545f3c0d30529accb15750d65 --- /dev/null +++ b/cpp/src/core/unittest/test_nsg/CMakeLists.txt @@ -0,0 +1,30 @@ +############################## +#include_directories(/usr/local/include/gperftools) +#link_directories(/usr/local/lib) + +add_definitions(-std=c++11 -O3 -lboost -march=native -Wall -DINFO) + +find_package(OpenMP) +if (OPENMP_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +else () + message(FATAL_ERROR "no OpenMP supprot") +endif () +message(${OpenMP_CXX_FLAGS}) + +include_directories(${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/nsg) +aux_source_directory(${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/nsg nsg_src) + +set(interface_src + ${CORE_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNSG.cpp + ) + +if(NOT TARGET test_nsg) + add_executable(test_nsg test_nsg.cpp ${interface_src} ${nsg_src} ${util_srcs} ${ivf_srcs}) +endif() + +target_link_libraries(test_nsg ${depend_libs} ${unittest_libs} ${basic_libs}) +############################## + +install(TARGETS test_nsg DESTINATION unittest) \ No newline at end of file diff --git a/cpp/src/core/unittest/test_nsg/test_nsg.cpp b/cpp/src/core/unittest/test_nsg/test_nsg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b59f0d49284bf1c9b01e1ea78c37058276868788 --- /dev/null +++ b/cpp/src/core/unittest/test_nsg/test_nsg.cpp @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/IndexNSG.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_index/nsg/NSGIO.h" + +#include "unittest/utils.h" + +namespace { + +namespace kn = knowhere; + +} // namespace + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr int64_t DEVICE_ID = 1; + +class NSGInterfaceTest : public DataGen, public ::testing::Test { + protected: + void + SetUp() override { + // Init_with_default(); + kn::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024 * 1024 * 200, 1024 * 1024 * 600, 2); + Generate(256, 1000000, 1); + index_ = std::make_shared(); + + auto tmp_conf = std::make_shared(); + tmp_conf->gpu_id = DEVICE_ID; + tmp_conf->knng = 100; + tmp_conf->nprobe = 32; + tmp_conf->nlist = 16384; + tmp_conf->search_length = 60; + tmp_conf->out_degree = 70; + tmp_conf->candidate_pool_size = 500; + tmp_conf->metric_type = kn::METRICTYPE::L2; + train_conf = tmp_conf; + + auto tmp2_conf = std::make_shared(); + tmp2_conf->k = k; + tmp2_conf->search_length = 30; + search_conf = tmp2_conf; + } + + void + TearDown() override { + kn::FaissGpuResourceMgr::GetInstance().Free(); + } + + protected: + std::shared_ptr index_; + kn::Config train_conf; + kn::Config search_conf; +}; + +void +AssertAnns(const kn::DatasetPtr& result, const int& nq, const int& k) { + auto ids = result->array()[0]; + for (auto i = 0; i < nq; i++) { + EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); + } +} + +TEST_F(NSGInterfaceTest, basic_test) { + assert(!xb.empty()); + + auto model = index_->Train(base_dataset, train_conf); + auto result = index_->Search(query_dataset, search_conf); + AssertAnns(result, nq, k); + + auto binaryset = index_->Serialize(); + auto new_index = std::make_shared(); + new_index->Load(binaryset); + auto new_result = new_index->Search(query_dataset, search_conf); + AssertAnns(result, nq, k); + + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dimension(), dim); + ASSERT_THROW({ index_->Clone(); }, knowhere::KnowhereException); + ASSERT_NO_THROW({ + index_->Add(base_dataset, kn::Config()); + index_->Seal(); + }); + + { + // std::cout << "k = 1" << std::endl; + // new_index->Search(GenQuery(1), Config::object{{"k", 1}}); + // new_index->Search(GenQuery(10), Config::object{{"k", 1}}); + // new_index->Search(GenQuery(100), Config::object{{"k", 1}}); + // new_index->Search(GenQuery(1000), Config::object{{"k", 1}}); + // new_index->Search(GenQuery(10000), Config::object{{"k", 1}}); + + // std::cout << "k = 5" << std::endl; + // new_index->Search(GenQuery(1), Config::object{{"k", 5}}); + // new_index->Search(GenQuery(20), Config::object{{"k", 5}}); + // new_index->Search(GenQuery(100), Config::object{{"k", 5}}); + // new_index->Search(GenQuery(300), Config::object{{"k", 5}}); + // new_index->Search(GenQuery(500), Config::object{{"k", 5}}); + } +} diff --git a/cpp/src/core/unittest/utils.cpp b/cpp/src/core/unittest/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c342738c1e2959a768cc5bac27ed9fe287e2767 --- /dev/null +++ b/cpp/src/core/unittest/utils.cpp @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "unittest/utils.h" + +#include +#include +#include + +INITIALIZE_EASYLOGGINGPP + +namespace { + +namespace kn = knowhere; + +} // namespace + +void +InitLog() { + el::Configurations defaultConf; + defaultConf.setToDefault(); + defaultConf.set(el::Level::Debug, el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)"); + el::Loggers::reconfigureLogger("default", defaultConf); +} + +void +DataGen::Init_with_default() { + Generate(dim, nb, nq); +} + +void +DataGen::Generate(const int& dim, const int& nb, const int& nq) { + this->nb = nb; + this->nq = nq; + this->dim = dim; + + GenAll(dim, nb, xb, ids, nq, xq); + assert(xb.size() == (size_t)dim * nb); + assert(xq.size() == (size_t)dim * nq); + + base_dataset = generate_dataset(nb, dim, xb.data(), ids.data()); + query_dataset = generate_query_dataset(nq, dim, xq.data()); +} + +knowhere::DatasetPtr +DataGen::GenQuery(const int& nq) { + xq.resize(nq * dim); + for (int i = 0; i < nq * dim; ++i) { + xq[i] = xb[i]; + } + return generate_query_dataset(nq, dim, xq.data()); +} + +void +GenAll(const int64_t dim, const int64_t& nb, std::vector& xb, std::vector& ids, const int64_t& nq, + std::vector& xq) { + xb.resize(nb * dim); + xq.resize(nq * dim); + ids.resize(nb); + GenAll(dim, nb, xb.data(), ids.data(), nq, xq.data()); +} + +void +GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq) { + GenBase(dim, nb, xb, ids); + for (int64_t i = 0; i < nq * dim; ++i) { + xq[i] = xb[i]; + } +} + +void +GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids) { + for (auto i = 0; i < nb; ++i) { + for (auto j = 0; j < dim; ++j) { + // p_data[i * d + j] = float(base + i); + xb[i * dim + j] = drand48(); + } + xb[dim * i] += i / 1000.; + ids[i] = i; + } +} + +FileIOReader::FileIOReader(const std::string& fname) { + name = fname; + fs = std::fstream(name, std::ios::in | std::ios::binary); +} + +FileIOReader::~FileIOReader() { + fs.close(); +} + +size_t +FileIOReader::operator()(void* ptr, size_t size) { + fs.read(reinterpret_cast(ptr), size); + return size; +} + +FileIOWriter::FileIOWriter(const std::string& fname) { + name = fname; + fs = std::fstream(name, std::ios::out | std::ios::binary); +} + +FileIOWriter::~FileIOWriter() { + fs.close(); +} + +size_t +FileIOWriter::operator()(void* ptr, size_t size) { + fs.write(reinterpret_cast(ptr), size); + return size; +} + +kn::DatasetPtr +generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids) { + std::vector shape{nb, dim}; + auto tensor = kn::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape); + std::vector tensors{tensor}; + std::vector tensor_fields{kn::ConstructFloatField("data")}; + auto tensor_schema = std::make_shared(tensor_fields); + + auto id_array = kn::ConstructInt64Array((uint8_t*)ids, nb * sizeof(int64_t)); + std::vector arrays{id_array}; + std::vector array_fields{kn::ConstructInt64Field("id")}; + auto array_schema = std::make_shared(tensor_fields); + + auto dataset = std::make_shared(std::move(arrays), array_schema, std::move(tensors), tensor_schema); + return dataset; +} + +kn::DatasetPtr +generate_query_dataset(int64_t nb, int64_t dim, float* xb) { + std::vector shape{nb, dim}; + auto tensor = kn::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape); + std::vector tensors{tensor}; + std::vector tensor_fields{kn::ConstructFloatField("data")}; + auto tensor_schema = std::make_shared(tensor_fields); + + auto dataset = std::make_shared(std::move(tensors), tensor_schema); + return dataset; +} diff --git a/cpp/src/core/test/utils.h b/cpp/src/core/unittest/utils.h similarity index 51% rename from cpp/src/core/test/utils.h rename to cpp/src/core/unittest/utils.h index 41dbbc1f30118978f1c9d1afd6ab9e25f969c12a..acc3e89183972190955639b70c19da3b8c15272c 100644 --- a/cpp/src/core/test/utils.h +++ b/cpp/src/core/unittest/utils.h @@ -15,24 +15,27 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include -#include #include +#include #include +#include +#include #include "knowhere/adapter/Structure.h" #include "knowhere/common/Log.h" -class DataGen { +class DataGen { protected: - void Init_with_default(); + void + Init_with_default(); - void Generate(const int &dim, const int &nb, const int &nq); + void + Generate(const int& dim, const int& nb, const int& nq); - zilliz::knowhere::DatasetPtr GenQuery(const int&nq); + knowhere::DatasetPtr + GenQuery(const int& nq); protected: int nb = 10000; @@ -42,53 +45,45 @@ class DataGen { std::vector xb; std::vector xq; std::vector ids; - zilliz::knowhere::DatasetPtr base_dataset = nullptr; - zilliz::knowhere::DatasetPtr query_dataset = nullptr; + knowhere::DatasetPtr base_dataset = nullptr; + knowhere::DatasetPtr query_dataset = nullptr; }; +extern void +GenAll(const int64_t dim, const int64_t& nb, std::vector& xb, std::vector& ids, const int64_t& nq, + std::vector& xq); -extern void GenAll(const int64_t dim, - const int64_t &nb, - std::vector &xb, - std::vector &ids, - const int64_t &nq, - std::vector &xq); - -extern void GenAll(const int64_t &dim, - const int64_t &nb, - float *xb, - int64_t *ids, - const int64_t &nq, - float *xq); +extern void +GenAll(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids, const int64_t& nq, float* xq); -extern void GenBase(const int64_t &dim, - const int64_t &nb, - float *xb, - int64_t *ids); +extern void +GenBase(const int64_t& dim, const int64_t& nb, float* xb, int64_t* ids); -extern void InitLog(); +extern void +InitLog(); -zilliz::knowhere::DatasetPtr -generate_dataset(int64_t nb, int64_t dim, float *xb, long *ids); +knowhere::DatasetPtr +generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids); -zilliz::knowhere::DatasetPtr -generate_query_dataset(int64_t nb, int64_t dim, float *xb); +knowhere::DatasetPtr +generate_query_dataset(int64_t nb, int64_t dim, float* xb); struct FileIOWriter { std::fstream fs; std::string name; - FileIOWriter(const std::string &fname); + explicit FileIOWriter(const std::string& fname); ~FileIOWriter(); - size_t operator()(void *ptr, size_t size); + size_t + operator()(void* ptr, size_t size); }; struct FileIOReader { std::fstream fs; std::string name; - FileIOReader(const std::string &fname); + explicit FileIOReader(const std::string& fname); ~FileIOReader(); - size_t operator()(void *ptr, size_t size); + size_t + operator()(void* ptr, size_t size); }; - diff --git a/cpp/src/db/Constants.h b/cpp/src/db/Constants.h index 2beb3c3a9792acba2c5e3b083eccc75c5f689e73..a1e09bc196d3df61eec9ccbd3df2fd380e996e40 100644 --- a/cpp/src/db/Constants.h +++ b/cpp/src/db/Constants.h @@ -19,7 +19,6 @@ #include -namespace zilliz { namespace milvus { namespace engine { @@ -36,6 +35,5 @@ static constexpr uint64_t ONE_KB = K; static constexpr uint64_t ONE_MB = ONE_KB * ONE_KB; static constexpr uint64_t ONE_GB = ONE_KB * ONE_MB; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/DB.h b/cpp/src/db/DB.h index 127812fc301e112c0da082e8a06bb6a1b5c9981b..a790fadb502f3f1d93d879873daa8a5adb228c44 100644 --- a/cpp/src/db/DB.h +++ b/cpp/src/db/DB.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "Options.h" @@ -23,11 +22,10 @@ #include "meta/Meta.h" #include "utils/Status.h" -#include #include +#include #include -namespace zilliz { namespace milvus { namespace engine { @@ -36,47 +34,64 @@ class Env; class DB { public: DB() = default; - DB(const DB &) = delete; - DB &operator=(const DB &) = delete; + DB(const DB&) = delete; + DB& + operator=(const DB&) = delete; virtual ~DB() = default; - virtual Status Start() = 0; - virtual Status Stop() = 0; - - virtual Status CreateTable(meta::TableSchema &table_schema_) = 0; - virtual Status DeleteTable(const std::string &table_id, const meta::DatesT &dates) = 0; - virtual Status DescribeTable(meta::TableSchema &table_schema_) = 0; - virtual Status HasTable(const std::string &table_id, bool &has_or_not_) = 0; - virtual Status AllTables(std::vector &table_schema_array) = 0; - virtual Status GetTableRowCount(const std::string &table_id, uint64_t &row_count) = 0; - virtual Status PreloadTable(const std::string &table_id) = 0; - virtual Status UpdateTableFlag(const std::string &table_id, int64_t flag) = 0; - - virtual Status InsertVectors(const std::string &table_id_, - uint64_t n, const float *vectors, IDNumbers &vector_ids_) = 0; - - virtual Status Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe, - const float *vectors, QueryResults &results) = 0; - - virtual Status Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe, - const float *vectors, const meta::DatesT &dates, QueryResults &results) = 0; - - virtual Status Query(const std::string &table_id, const std::vector &file_ids, - uint64_t k, uint64_t nq, uint64_t nprobe, const float *vectors, - const meta::DatesT &dates, QueryResults &results) = 0; - - virtual Status Size(uint64_t &result) = 0; - - virtual Status CreateIndex(const std::string &table_id, const TableIndex &index) = 0; - virtual Status DescribeIndex(const std::string &table_id, TableIndex &index) = 0; - virtual Status DropIndex(const std::string &table_id) = 0; - - virtual Status DropAll() = 0; -}; // DB + virtual Status + Start() = 0; + virtual Status + Stop() = 0; + + virtual Status + CreateTable(meta::TableSchema& table_schema_) = 0; + virtual Status + DeleteTable(const std::string& table_id, const meta::DatesT& dates) = 0; + virtual Status + DescribeTable(meta::TableSchema& table_schema_) = 0; + virtual Status + HasTable(const std::string& table_id, bool& has_or_not_) = 0; + virtual Status + AllTables(std::vector& table_schema_array) = 0; + virtual Status + GetTableRowCount(const std::string& table_id, uint64_t& row_count) = 0; + virtual Status + PreloadTable(const std::string& table_id) = 0; + virtual Status + UpdateTableFlag(const std::string& table_id, int64_t flag) = 0; + + virtual Status + InsertVectors(const std::string& table_id_, uint64_t n, const float* vectors, IDNumbers& vector_ids_) = 0; + + virtual Status + Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + QueryResults& results) = 0; + + virtual Status + Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + const meta::DatesT& dates, QueryResults& results) = 0; + + virtual Status + Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) = 0; + + virtual Status + Size(uint64_t& result) = 0; + + virtual Status + CreateIndex(const std::string& table_id, const TableIndex& index) = 0; + virtual Status + DescribeIndex(const std::string& table_id, TableIndex& index) = 0; + virtual Status + DropIndex(const std::string& table_id) = 0; + + virtual Status + DropAll() = 0; +}; // DB using DBPtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/DBFactory.cpp b/cpp/src/db/DBFactory.cpp index edf70cccf82c0e333286b61598f97a72e55e667f..fae3a180c07b00b9261dca4dbdfea4344178dc43 100644 --- a/cpp/src/db/DBFactory.cpp +++ b/cpp/src/db/DBFactory.cpp @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. - #include "db/DBFactory.h" #include "DBImpl.h" -#include "utils/Exception.h" #include "meta/MetaFactory.h" -#include "meta/SqliteMetaImpl.h" #include "meta/MySQLMetaImpl.h" +#include "meta/SqliteMetaImpl.h" +#include "utils/Exception.h" #include #include -#include #include +#include #include -namespace zilliz { namespace milvus { namespace engine { @@ -42,10 +40,9 @@ DBFactory::BuildOption() { } DBPtr -DBFactory::Build(const DBOptions &options) { +DBFactory::Build(const DBOptions& options) { return std::make_shared(options); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/DBFactory.h b/cpp/src/db/DBFactory.h index 1223a11c2907d0fe8cbba9712b17065497a8838d..e787f7c8839b414b192e4ab72a7677df29f56cd6 100644 --- a/cpp/src/db/DBFactory.h +++ b/cpp/src/db/DBFactory.h @@ -20,20 +20,20 @@ #include "DB.h" #include "Options.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { class DBFactory { public: - static DBOptions BuildOption(); + static DBOptions + BuildOption(); - static DBPtr Build(const DBOptions &options); + static DBPtr + Build(const DBOptions& options); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index ab3f3104af3bc10e6be3e65a077be93c002e72c2..f81fb32e7fde14f7614dc7f263f721ba6d8f21b0 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -16,30 +16,30 @@ // under the License. #include "db/DBImpl.h" +#include "Utils.h" #include "cache/CpuCacheMgr.h" #include "cache/GpuCacheMgr.h" #include "engine/EngineFactory.h" #include "insert/MemMenagerFactory.h" -#include "meta/SqliteMetaImpl.h" -#include "meta/MetaFactory.h" #include "meta/MetaConsts.h" +#include "meta/MetaFactory.h" +#include "meta/SqliteMetaImpl.h" #include "metrics/Metrics.h" -#include "scheduler/job/SearchJob.h" -#include "scheduler/job/DeleteJob.h" #include "scheduler/SchedInst.h" -#include "utils/TimeRecorder.h" +#include "scheduler/job/BuildIndexJob.h" +#include "scheduler/job/DeleteJob.h" +#include "scheduler/job/SearchJob.h" #include "utils/Log.h" -#include "Utils.h" +#include "utils/TimeRecorder.h" #include -#include -#include -#include -#include #include #include +#include +#include +#include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -49,13 +49,10 @@ constexpr uint64_t METRIC_ACTION_INTERVAL = 1; constexpr uint64_t COMPACT_ACTION_INTERVAL = 1; constexpr uint64_t INDEX_ACTION_INTERVAL = 1; -} // namespace +} // namespace -DBImpl::DBImpl(const DBOptions &options) - : options_(options), - shutting_down_(true), - compact_thread_pool_(1, 1), - index_thread_pool_(1, 1) { +DBImpl::DBImpl(const DBOptions& options) + : options_(options), shutting_down_(true), compact_thread_pool_(1, 1), index_thread_pool_(1, 1) { meta_ptr_ = MetaFactory::Build(options.meta_, options.mode_); mem_mgr_ = MemManagerFactory::Build(meta_ptr_, options_); Start(); @@ -66,7 +63,7 @@ DBImpl::~DBImpl() { } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -//external api +// external api /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// Status DBImpl::Start() { @@ -77,7 +74,7 @@ DBImpl::Start() { ENGINE_LOG_TRACE << "DB service start"; shutting_down_.store(false, std::memory_order_release); - //for distribute version, some nodes are read only + // for distribute version, some nodes are read only if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { ENGINE_LOG_TRACE << "StartTimerTasks"; bg_timer_thread_ = std::thread(&DBImpl::BackgroundTimerTask, this); @@ -94,10 +91,10 @@ DBImpl::Stop() { shutting_down_.store(true, std::memory_order_release); - //makesure all memory data serialized + // makesure all memory data serialized MemSerialize(); - //wait compaction/buildindex finish + // wait compaction/buildindex finish bg_timer_thread_.join(); if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { @@ -114,30 +111,30 @@ DBImpl::DropAll() { } Status -DBImpl::CreateTable(meta::TableSchema &table_schema) { +DBImpl::CreateTable(meta::TableSchema& table_schema) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } meta::TableSchema temp_schema = table_schema; - temp_schema.index_file_size_ *= ONE_MB; //store as MB + temp_schema.index_file_size_ *= ONE_MB; // store as MB return meta_ptr_->CreateTable(temp_schema); } Status -DBImpl::DeleteTable(const std::string &table_id, const meta::DatesT &dates) { +DBImpl::DeleteTable(const std::string& table_id, const meta::DatesT& dates) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } - //dates partly delete files of the table but currently we don't support + // dates partly delete files of the table but currently we don't support ENGINE_LOG_DEBUG << "Prepare to delete table " << table_id; if (dates.empty()) { - mem_mgr_->EraseMemVector(table_id); //not allow insert - meta_ptr_->DeleteTable(table_id); //soft delete table + mem_mgr_->EraseMemVector(table_id); // not allow insert + meta_ptr_->DeleteTable(table_id); // soft delete table - //scheduler will determine when to delete table files + // scheduler will determine when to delete table files auto nres = scheduler::ResMgrInst::GetInstance()->GetNumOfComputeResource(); scheduler::DeleteJobPtr job = std::make_shared(0, table_id, meta_ptr_, nres); scheduler::JobMgrInst::GetInstance()->Put(job); @@ -150,18 +147,18 @@ DBImpl::DeleteTable(const std::string &table_id, const meta::DatesT &dates) { } Status -DBImpl::DescribeTable(meta::TableSchema &table_schema) { +DBImpl::DescribeTable(meta::TableSchema& table_schema) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } auto stat = meta_ptr_->DescribeTable(table_schema); - table_schema.index_file_size_ /= ONE_MB; //return as MB + table_schema.index_file_size_ /= ONE_MB; // return as MB return stat; } Status -DBImpl::HasTable(const std::string &table_id, bool &has_or_not) { +DBImpl::HasTable(const std::string& table_id, bool& has_or_not) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -170,7 +167,7 @@ DBImpl::HasTable(const std::string &table_id, bool &has_or_not) { } Status -DBImpl::AllTables(std::vector &table_schema_array) { +DBImpl::AllTables(std::vector& table_schema_array) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -179,7 +176,7 @@ DBImpl::AllTables(std::vector &table_schema_array) { } Status -DBImpl::PreloadTable(const std::string &table_id) { +DBImpl::PreloadTable(const std::string& table_id) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -198,13 +195,11 @@ DBImpl::PreloadTable(const std::string &table_id) { int64_t cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); int64_t available_size = cache_total - cache_usage; - for (auto &day_files : files) { - for (auto &file : day_files.second) { - ExecutionEnginePtr engine = EngineFactory::Build(file.dimension_, - file.location_, - (EngineType) file.engine_type_, - (MetricType) file.metric_type_, - file.nlist_); + for (auto& day_files : files) { + for (auto& file : day_files.second) { + ExecutionEnginePtr engine = + EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_, + (MetricType)file.metric_type_, file.nlist_); if (engine == nullptr) { ENGINE_LOG_ERROR << "Invalid engine type"; return Status(DB_ERROR, "Invalid engine type"); @@ -212,12 +207,12 @@ DBImpl::PreloadTable(const std::string &table_id) { size += engine->PhysicalSize(); if (size > available_size) { - break; + return Status(SERVER_CACHE_FULL, "Cache is full"); } else { try { - //step 1: load index + // step 1: load index engine->Load(true); - } catch (std::exception &ex) { + } catch (std::exception& ex) { std::string msg = "Pre-load table encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; return Status(DB_ERROR, msg); @@ -229,7 +224,7 @@ DBImpl::PreloadTable(const std::string &table_id) { } Status -DBImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) { +DBImpl::UpdateTableFlag(const std::string& table_id, int64_t flag) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -238,7 +233,7 @@ DBImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) { } Status -DBImpl::GetTableRowCount(const std::string &table_id, uint64_t &row_count) { +DBImpl::GetTableRowCount(const std::string& table_id, uint64_t& row_count) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -247,31 +242,30 @@ DBImpl::GetTableRowCount(const std::string &table_id, uint64_t &row_count) { } Status -DBImpl::InsertVectors(const std::string &table_id_, - uint64_t n, const float *vectors, IDNumbers &vector_ids_) { -// ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache"; +DBImpl::InsertVectors(const std::string& table_id, uint64_t n, const float* vectors, IDNumbers& vector_ids) { + // ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache"; if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } Status status; - zilliz::milvus::server::CollectInsertMetrics metrics(n, status); - status = mem_mgr_->InsertVectors(table_id_, n, vectors, vector_ids_); -// std::chrono::microseconds time_span = -// std::chrono::duration_cast(end_time - start_time); -// double average_time = double(time_span.count()) / n; + milvus::server::CollectInsertMetrics metrics(n, status); + status = mem_mgr_->InsertVectors(table_id, n, vectors, vector_ids); + // std::chrono::microseconds time_span = + // std::chrono::duration_cast(end_time - start_time); + // double average_time = double(time_span.count()) / n; -// ENGINE_LOG_DEBUG << "Insert vectors to cache finished"; + // ENGINE_LOG_DEBUG << "Insert vectors to cache finished"; return status; } Status -DBImpl::CreateIndex(const std::string &table_id, const TableIndex &index) { +DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) { { std::unique_lock lock(build_index_mutex_); - //step 1: check index difference + // step 1: check index difference TableIndex old_index; auto status = DescribeIndex(table_id, old_index); if (!status.ok()) { @@ -279,9 +273,9 @@ DBImpl::CreateIndex(const std::string &table_id, const TableIndex &index) { return status; } - //step 2: update index info + // step 2: update index info TableIndex new_index = index; - new_index.metric_type_ = old_index.metric_type_;//dont change metric type, it was defined by CreateTable + new_index.metric_type_ = old_index.metric_type_; // dont change metric type, it was defined by CreateTable if (!utils::IsSameIndex(old_index, new_index)) { DropIndex(table_id); @@ -293,26 +287,26 @@ DBImpl::CreateIndex(const std::string &table_id, const TableIndex &index) { } } - //step 3: let merge file thread finish - //to avoid duplicate data bug + // step 3: let merge file thread finish + // to avoid duplicate data bug WaitMergeFileFinish(); - //step 4: wait and build index - //for IDMAP type, only wait all NEW file converted to RAW file - //for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files + // step 4: wait and build index + // for IDMAP type, only wait all NEW file converted to RAW file + // for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files std::vector file_types; - if (index.engine_type_ == (int) EngineType::FAISS_IDMAP) { + if (index.engine_type_ == static_cast(EngineType::FAISS_IDMAP)) { file_types = { - (int) meta::TableFileSchema::NEW, - (int) meta::TableFileSchema::NEW_MERGE, + static_cast(meta::TableFileSchema::NEW), + static_cast(meta::TableFileSchema::NEW_MERGE), }; } else { file_types = { - (int) meta::TableFileSchema::RAW, - (int) meta::TableFileSchema::NEW, - (int) meta::TableFileSchema::NEW_MERGE, - (int) meta::TableFileSchema::NEW_INDEX, - (int) meta::TableFileSchema::TO_INDEX, + static_cast(meta::TableFileSchema::RAW), + static_cast(meta::TableFileSchema::NEW), + static_cast(meta::TableFileSchema::NEW_MERGE), + static_cast(meta::TableFileSchema::NEW_INDEX), + static_cast(meta::TableFileSchema::TO_INDEX), }; } @@ -322,7 +316,7 @@ DBImpl::CreateIndex(const std::string &table_id, const TableIndex &index) { while (!file_ids.empty()) { ENGINE_LOG_DEBUG << "Non index files detected! Will build index " << times; - if (index.engine_type_ != (int) EngineType::FAISS_IDMAP) { + if (index.engine_type_ != (int)EngineType::FAISS_IDMAP) { status = meta_ptr_->UpdateTableFilesToIndex(table_id); } @@ -335,19 +329,19 @@ DBImpl::CreateIndex(const std::string &table_id, const TableIndex &index) { } Status -DBImpl::DescribeIndex(const std::string &table_id, TableIndex &index) { +DBImpl::DescribeIndex(const std::string& table_id, TableIndex& index) { return meta_ptr_->DescribeTableIndex(table_id, index); } Status -DBImpl::DropIndex(const std::string &table_id) { +DBImpl::DropIndex(const std::string& table_id) { ENGINE_LOG_DEBUG << "Drop index for table: " << table_id; return meta_ptr_->DropTableIndex(table_id); } Status -DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe, - const float *vectors, QueryResults &results) { +DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + QueryResults& results) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -359,46 +353,47 @@ DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t npr } Status -DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe, - const float *vectors, const meta::DatesT &dates, QueryResults &results) { +DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + const meta::DatesT& dates, QueryResults& results) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } ENGINE_LOG_DEBUG << "Query by dates for table: " << table_id; - //get all table files from table + // get all table files from table meta::DatePartionedTableFilesSchema files; std::vector ids; auto status = meta_ptr_->FilesToSearch(table_id, ids, dates, files); - if (!status.ok()) { return status; } + if (!status.ok()) { + return status; + } meta::TableFilesSchema file_id_array; - for (auto &day_files : files) { - for (auto &file : day_files.second) { + for (auto& day_files : files) { + for (auto& file : day_files.second) { file_id_array.push_back(file); } } - cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query + cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results); - cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query + cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query return status; } Status -DBImpl::Query(const std::string &table_id, const std::vector &file_ids, - uint64_t k, uint64_t nq, uint64_t nprobe, const float *vectors, - const meta::DatesT &dates, QueryResults &results) { +DBImpl::Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } ENGINE_LOG_DEBUG << "Query by file ids for table: " << table_id; - //get specified files + // get specified files std::vector ids; - for (auto &id : file_ids) { + for (auto& id : file_ids) { meta::TableFileSchema table_file; table_file.table_id_ = table_id; std::string::size_type sz; @@ -412,8 +407,8 @@ DBImpl::Query(const std::string &table_id, const std::vector &file_ } meta::TableFilesSchema file_id_array; - for (auto &day_files : files_array) { - for (auto &file : day_files.second) { + for (auto& day_files : files_array) { + for (auto& file : day_files.second) { file_id_array.push_back(file); } } @@ -422,14 +417,14 @@ DBImpl::Query(const std::string &table_id, const std::vector &file_ return Status(DB_ERROR, "Invalid file id"); } - cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info before query + cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results); - cache::CpuCacheMgr::GetInstance()->PrintInfo(); //print cache info after query + cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query return status; } Status -DBImpl::Size(uint64_t &result) { +DBImpl::Size(uint64_t& result) { if (shutting_down_.load(std::memory_order_acquire)) { return Status(DB_ERROR, "Milsvus server is shutdown!"); } @@ -438,58 +433,57 @@ DBImpl::Size(uint64_t &result) { } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -//internal methods +// internal methods /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// Status -DBImpl::QueryAsync(const std::string &table_id, const meta::TableFilesSchema &files, - uint64_t k, uint64_t nq, uint64_t nprobe, const float *vectors, - const meta::DatesT &dates, QueryResults &results) { +DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { server::CollectQueryMetrics metrics(nq); TimeRecorder rc(""); - //step 1: get files to search - ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size() << " date range count: " - << dates.size(); + // step 1: get files to search + ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size() + << " date range count: " << dates.size(); scheduler::SearchJobPtr job = std::make_shared(0, k, nq, nprobe, vectors); - for (auto &file : files) { + for (auto& file : files) { scheduler::TableFileSchemaPtr file_ptr = std::make_shared(file); job->AddIndexFile(file_ptr); } - //step 2: put search task to scheduler + // step 2: put search task to scheduler scheduler::JobMgrInst::GetInstance()->Put(job); job->WaitResult(); if (!job->GetStatus().ok()) { return job->GetStatus(); } - //step 3: print time cost information -// double load_cost = context->LoadCost(); -// double search_cost = context->SearchCost(); -// double reduce_cost = context->ReduceCost(); -// std::string load_info = TimeRecorder::GetTimeSpanStr(load_cost); -// std::string search_info = TimeRecorder::GetTimeSpanStr(search_cost); -// std::string reduce_info = TimeRecorder::GetTimeSpanStr(reduce_cost); -// if(search_cost > 0.0 || reduce_cost > 0.0) { -// double total_cost = load_cost + search_cost + reduce_cost; -// double load_percent = load_cost/total_cost; -// double search_percent = search_cost/total_cost; -// double reduce_percent = reduce_cost/total_cost; -// -// ENGINE_LOG_DEBUG << "Engine load index totally cost: " << load_info -// << " percent: " << load_percent*100 << "%"; -// ENGINE_LOG_DEBUG << "Engine search index totally cost: " << search_info -// << " percent: " << search_percent*100 << "%"; -// ENGINE_LOG_DEBUG << "Engine reduce topk totally cost: " << reduce_info -// << " percent: " << reduce_percent*100 << "%"; -// } else { -// ENGINE_LOG_DEBUG << "Engine load cost: " << load_info -// << " search cost: " << search_info -// << " reduce cost: " << reduce_info; -// } - - //step 4: construct results + // step 3: print time cost information + // double load_cost = context->LoadCost(); + // double search_cost = context->SearchCost(); + // double reduce_cost = context->ReduceCost(); + // std::string load_info = TimeRecorder::GetTimeSpanStr(load_cost); + // std::string search_info = TimeRecorder::GetTimeSpanStr(search_cost); + // std::string reduce_info = TimeRecorder::GetTimeSpanStr(reduce_cost); + // if(search_cost > 0.0 || reduce_cost > 0.0) { + // double total_cost = load_cost + search_cost + reduce_cost; + // double load_percent = load_cost/total_cost; + // double search_percent = search_cost/total_cost; + // double reduce_percent = reduce_cost/total_cost; + // + // ENGINE_LOG_DEBUG << "Engine load index totally cost: " << load_info + // << " percent: " << load_percent*100 << "%"; + // ENGINE_LOG_DEBUG << "Engine search index totally cost: " << search_info + // << " percent: " << search_percent*100 << "%"; + // ENGINE_LOG_DEBUG << "Engine reduce topk totally cost: " << reduce_info + // << " percent: " << reduce_percent*100 << "%"; + // } else { + // ENGINE_LOG_DEBUG << "Engine load cost: " << load_info + // << " search cost: " << search_info + // << " reduce cost: " << reduce_info; + // } + + // step 4: construct results results = job->GetResult(); rc.ElapseFromBegin("Engine query totally cost"); @@ -520,7 +514,7 @@ DBImpl::BackgroundTimerTask() { void DBImpl::WaitMergeFileFinish() { std::lock_guard lck(compact_result_mutex_); - for (auto &iter : compact_thread_results_) { + for (auto& iter : compact_thread_results_) { iter.wait(); } } @@ -528,7 +522,7 @@ DBImpl::WaitMergeFileFinish() { void DBImpl::WaitBuildIndexFinish() { std::lock_guard lck(index_result_mutex_); - for (auto &iter : index_thread_results_) { + for (auto& iter : index_thread_results_) { iter.wait(); } } @@ -569,7 +563,7 @@ DBImpl::MemSerialize() { std::lock_guard lck(mem_serialize_mutex_); std::set temp_table_ids; mem_mgr_->Serialize(temp_table_ids); - for (auto &id : temp_table_ids) { + for (auto& id : temp_table_ids) { compact_table_ids_.insert(id); } @@ -588,10 +582,10 @@ DBImpl::StartCompactionTask() { return; } - //serialize memory data + // serialize memory data MemSerialize(); - //compactiong has been finished? + // compactiong has been finished? { std::lock_guard lck(compact_result_mutex_); if (!compact_thread_results_.empty()) { @@ -602,7 +596,7 @@ DBImpl::StartCompactionTask() { } } - //add new compaction task + // add new compaction task { std::lock_guard lck(compact_result_mutex_); if (compact_thread_results_.empty()) { @@ -614,11 +608,10 @@ DBImpl::StartCompactionTask() { } Status -DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, - const meta::TableFilesSchema &files) { +DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date, const meta::TableFilesSchema& files) { ENGINE_LOG_DEBUG << "Merge files for table: " << table_id; - //step 1: create table file + // step 1: create table file meta::TableFileSchema table_file; table_file.table_id_ = table_id; table_file.date_ = date; @@ -630,15 +623,15 @@ DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, return status; } - //step 2: merge files + // step 2: merge files ExecutionEnginePtr index = - EngineFactory::Build(table_file.dimension_, table_file.location_, (EngineType) table_file.engine_type_, - (MetricType) table_file.metric_type_, table_file.nlist_); + EngineFactory::Build(table_file.dimension_, table_file.location_, (EngineType)table_file.engine_type_, + (MetricType)table_file.metric_type_, table_file.nlist_); meta::TableFilesSchema updated; int64_t index_size = 0; - for (auto &file : files) { + for (auto& file : files) { server::CollectMergeFilesMetrics metrics; index->Merge(file.location_); @@ -648,14 +641,16 @@ DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, ENGINE_LOG_DEBUG << "Merging file " << file_schema.file_id_; index_size = index->Size(); - if (index_size >= file_schema.index_file_size_) break; + if (index_size >= file_schema.index_file_size_) { + break; + } } - //step 3: serialize to disk + // step 3: serialize to disk try { index->Serialize(); - } catch (std::exception &ex) { - //typical error: out of disk space or permition denied + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied std::string msg = "Serialize merged index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; @@ -669,12 +664,12 @@ DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, return Status(DB_ERROR, msg); } - //step 4: update table files state - //if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size - //else set file type to RAW, no need to build index - if (table_file.engine_type_ != (int) EngineType::FAISS_IDMAP) { - table_file.file_type_ = (index->PhysicalSize() >= table_file.index_file_size_) ? - meta::TableFileSchema::TO_INDEX : meta::TableFileSchema::RAW; + // step 4: update table files state + // if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size + // else set file type to RAW, no need to build index + if (table_file.engine_type_ != (int)EngineType::FAISS_IDMAP) { + table_file.file_type_ = (index->PhysicalSize() >= table_file.index_file_size_) ? meta::TableFileSchema::TO_INDEX + : meta::TableFileSchema::RAW; } else { table_file.file_type_ = meta::TableFileSchema::RAW; } @@ -682,8 +677,7 @@ DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, table_file.row_count_ = index->Count(); updated.push_back(table_file); status = meta_ptr_->UpdateTableFiles(updated); - ENGINE_LOG_DEBUG << "New merged file " << table_file.file_id_ << - " of size " << index->PhysicalSize() << " bytes"; + ENGINE_LOG_DEBUG << "New merged file " << table_file.file_id_ << " of size " << index->PhysicalSize() << " bytes"; if (options_.insert_cache_immediately_) { index->Cache(); @@ -693,7 +687,7 @@ DBImpl::MergeFiles(const std::string &table_id, const meta::DateT &date, } Status -DBImpl::BackgroundMergeFiles(const std::string &table_id) { +DBImpl::BackgroundMergeFiles(const std::string& table_id) { meta::DatePartionedTableFilesSchema raw_files; auto status = meta_ptr_->FilesToMerge(table_id, raw_files); if (!status.ok()) { @@ -702,7 +696,7 @@ DBImpl::BackgroundMergeFiles(const std::string &table_id) { } bool has_merge = false; - for (auto &kv : raw_files) { + for (auto& kv : raw_files) { auto files = kv.second; if (files.size() < options_.merge_trigger_number_) { ENGINE_LOG_DEBUG << "Files number not greater equal than merge trigger number, skip merge action"; @@ -725,7 +719,7 @@ DBImpl::BackgroundCompaction(std::set table_ids) { ENGINE_LOG_TRACE << " Background compaction thread start"; Status status; - for (auto &table_id : table_ids) { + for (auto& table_id : table_ids) { status = BackgroundMergeFiles(table_id); if (!status.ok()) { ENGINE_LOG_ERROR << "Merge files for table " << table_id << " failed: " << status.ToString(); @@ -739,7 +733,7 @@ DBImpl::BackgroundCompaction(std::set table_ids) { meta_ptr_->Archive(); - int ttl = 5 * meta::M_SEC;//default: file will be deleted after 5 minutes + int ttl = 5 * meta::M_SEC; // default: file will be deleted after 5 minutes if (options_.mode_ == DBOptions::MODE::CLUSTER_WRITABLE) { ttl = meta::D_SEC; } @@ -756,7 +750,7 @@ DBImpl::StartBuildIndexTask(bool force) { return; } - //build index has been finished? + // build index has been finished? { std::lock_guard lck(index_result_mutex_); if (!index_thread_results_.empty()) { @@ -767,52 +761,50 @@ DBImpl::StartBuildIndexTask(bool force) { } } - //add new build index task + // add new build index task { std::lock_guard lck(index_result_mutex_); if (index_thread_results_.empty()) { - index_thread_results_.push_back( - index_thread_pool_.enqueue(&DBImpl::BackgroundBuildIndex, this)); + index_thread_results_.push_back(index_thread_pool_.enqueue(&DBImpl::BackgroundBuildIndex, this)); } } } Status -DBImpl::BuildIndex(const meta::TableFileSchema &file) { - ExecutionEnginePtr to_index = - EngineFactory::Build(file.dimension_, file.location_, (EngineType) file.engine_type_, - (MetricType) file.metric_type_, file.nlist_); +DBImpl::BuildIndex(const meta::TableFileSchema& file) { + ExecutionEnginePtr to_index = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_, + (MetricType)file.metric_type_, file.nlist_); if (to_index == nullptr) { ENGINE_LOG_ERROR << "Invalid engine type"; return Status(DB_ERROR, "Invalid engine type"); } try { - //step 1: load index + // step 1: load index Status status = to_index->Load(options_.insert_cache_immediately_); if (!status.ok()) { ENGINE_LOG_ERROR << "Failed to load index file: " << status.ToString(); return status; } - //step 2: create table file + // step 2: create table file meta::TableFileSchema table_file; table_file.table_id_ = file.table_id_; table_file.date_ = file.date_; table_file.file_type_ = - meta::TableFileSchema::NEW_INDEX; //for multi-db-path, distribute index file averagely to each path + meta::TableFileSchema::NEW_INDEX; // for multi-db-path, distribute index file averagely to each path status = meta_ptr_->CreateTableFile(table_file); if (!status.ok()) { ENGINE_LOG_ERROR << "Failed to create table file: " << status.ToString(); return status; } - //step 3: build index + // step 3: build index std::shared_ptr index; try { server::CollectBuildIndexMetrics metrics; - index = to_index->BuildIndex(table_file.location_, (EngineType) table_file.engine_type_); + index = to_index->BuildIndex(table_file.location_, (EngineType)table_file.engine_type_); if (index == nullptr) { table_file.file_type_ = meta::TableFileSchema::TO_DELETE; status = meta_ptr_->UpdateTableFile(table_file); @@ -821,8 +813,8 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { return status; } - } catch (std::exception &ex) { - //typical error: out of gpu memory + } catch (std::exception& ex) { + // typical error: out of gpu memory std::string msg = "BuildIndex encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; @@ -836,7 +828,7 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { return Status(DB_ERROR, msg); } - //step 4: if table has been deleted, dont save index file + // step 4: if table has been deleted, dont save index file bool has_table = false; meta_ptr_->HasTable(file.table_id_, has_table); if (!has_table) { @@ -844,11 +836,11 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { return Status::OK(); } - //step 5: save index file + // step 5: save index file try { index->Serialize(); - } catch (std::exception &ex) { - //typical error: out of disk space or permition denied + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied std::string msg = "Serialize index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; @@ -862,7 +854,7 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { return Status(DB_ERROR, msg); } - //step 6: update meta + // step 6: update meta table_file.file_type_ = meta::TableFileSchema::INDEX; table_file.file_size_ = index->PhysicalSize(); table_file.row_count_ = index->Count(); @@ -873,15 +865,15 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { meta::TableFilesSchema update_files = {table_file, origin_file}; status = meta_ptr_->UpdateTableFiles(update_files); if (status.ok()) { - ENGINE_LOG_DEBUG << "New index file " << table_file.file_id_ << " of size " - << index->PhysicalSize() << " bytes" + ENGINE_LOG_DEBUG << "New index file " << table_file.file_id_ << " of size " << index->PhysicalSize() + << " bytes" << " from file " << origin_file.file_id_; if (options_.insert_cache_immediately_) { index->Cache(); } } else { - //failed to update meta, mark the new file as to_delete, don't delete old file + // failed to update meta, mark the new file as to_delete, don't delete old file origin_file.file_type_ = meta::TableFileSchema::TO_INDEX; status = meta_ptr_->UpdateTableFile(origin_file); ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << origin_file.file_id_ << " to to_index"; @@ -890,7 +882,7 @@ DBImpl::BuildIndex(const meta::TableFileSchema &file) { status = meta_ptr_->UpdateTableFile(table_file); ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete"; } - } catch (std::exception &ex) { + } catch (std::exception& ex) { std::string msg = "Build index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; return Status(DB_ERROR, msg); @@ -907,21 +899,35 @@ DBImpl::BackgroundBuildIndex() { meta::TableFilesSchema to_index_files; meta_ptr_->FilesToIndex(to_index_files); Status status; - for (auto &file : to_index_files) { - status = BuildIndex(file); - if (!status.ok()) { - ENGINE_LOG_ERROR << "Building index for " << file.id_ << " failed: " << status.ToString(); - } - if (shutting_down_.load(std::memory_order_acquire)) { - ENGINE_LOG_DEBUG << "Server will shutdown, skip build index action"; - break; - } + scheduler::BuildIndexJobPtr job = std::make_shared(0, meta_ptr_, options_); + + // step 2: put build index task to scheduler + for (auto& file : to_index_files) { + scheduler::TableFileSchemaPtr file_ptr = std::make_shared(file); + job->AddToIndexFiles(file_ptr); } + scheduler::JobMgrInst::GetInstance()->Put(job); + job->WaitBuildIndexFinish(); + if (!job->GetStatus().ok()) { + Status status = job->GetStatus(); + ENGINE_LOG_ERROR << "Building index failed: " << status.ToString(); + } + + // for (auto &file : to_index_files) { + // status = BuildIndex(file); + // if (!status.ok()) { + // ENGINE_LOG_ERROR << "Building index for " << file.id_ << " failed: " << status.ToString(); + // } + // + // if (shutting_down_.load(std::memory_order_acquire)) { + // ENGINE_LOG_DEBUG << "Server will shutdown, skip build index action"; + // break; + // } + // } ENGINE_LOG_TRACE << "Background build index thread exit"; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/DBImpl.h b/cpp/src/db/DBImpl.h index 344e290445b8ef91793af4cd2c4903e167e359a0..865b3dfa5361e4ccfa22fb1df223d3d5badd13fa 100644 --- a/cpp/src/db/DBImpl.h +++ b/cpp/src/db/DBImpl.h @@ -19,20 +19,19 @@ #include "DB.h" #include "Types.h" -#include "utils/ThreadPool.h" #include "src/db/insert/MemManager.h" +#include "utils/ThreadPool.h" -#include -#include -#include #include -#include +#include #include +#include +#include #include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -44,92 +43,101 @@ class Meta; class DBImpl : public DB { public: - explicit DBImpl(const DBOptions &options); + explicit DBImpl(const DBOptions& options); ~DBImpl(); - Status Start() override; - Status Stop() override; - Status DropAll() override; + Status + Start() override; + Status + Stop() override; + Status + DropAll() override; - Status CreateTable(meta::TableSchema &table_schema) override; + Status + CreateTable(meta::TableSchema& table_schema) override; - Status DeleteTable(const std::string &table_id, const meta::DatesT &dates) override; + Status + DeleteTable(const std::string& table_id, const meta::DatesT& dates) override; - Status DescribeTable(meta::TableSchema &table_schema) override; + Status + DescribeTable(meta::TableSchema& table_schema) override; - Status HasTable(const std::string &table_id, bool &has_or_not) override; + Status + HasTable(const std::string& table_id, bool& has_or_not) override; - Status AllTables(std::vector &table_schema_array) override; + Status + AllTables(std::vector& table_schema_array) override; - Status PreloadTable(const std::string &table_id) override; + Status + PreloadTable(const std::string& table_id) override; - Status UpdateTableFlag(const std::string &table_id, int64_t flag); + Status + UpdateTableFlag(const std::string& table_id, int64_t flag); - Status GetTableRowCount(const std::string &table_id, uint64_t &row_count) override; + Status + GetTableRowCount(const std::string& table_id, uint64_t& row_count) override; - Status InsertVectors(const std::string &table_id, uint64_t n, const float *vectors, IDNumbers &vector_ids) override; + Status + InsertVectors(const std::string& table_id, uint64_t n, const float* vectors, IDNumbers& vector_ids) override; - Status CreateIndex(const std::string &table_id, const TableIndex &index) override; + Status + CreateIndex(const std::string& table_id, const TableIndex& index) override; - Status DescribeIndex(const std::string &table_id, TableIndex &index) override; + Status + DescribeIndex(const std::string& table_id, TableIndex& index) override; - Status DropIndex(const std::string &table_id) override; + Status + DropIndex(const std::string& table_id) override; - Status Query(const std::string &table_id, - uint64_t k, - uint64_t nq, - uint64_t nprobe, - const float *vectors, - QueryResults &results) override; + Status + Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + QueryResults& results) override; - Status Query(const std::string &table_id, - uint64_t k, - uint64_t nq, - uint64_t nprobe, - const float *vectors, - const meta::DatesT &dates, - QueryResults &results) override; + Status + Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, + const meta::DatesT& dates, QueryResults& results) override; - Status Query(const std::string &table_id, - const std::vector &file_ids, - uint64_t k, - uint64_t nq, - uint64_t nprobe, - const float *vectors, - const meta::DatesT &dates, - QueryResults &results) override; + Status + Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) override; - Status Size(uint64_t &result) override; + Status + Size(uint64_t& result) override; private: - Status QueryAsync(const std::string &table_id, - const meta::TableFilesSchema &files, - uint64_t k, - uint64_t nq, - uint64_t nprobe, - const float *vectors, - const meta::DatesT &dates, - QueryResults &results); - - void BackgroundTimerTask(); - void WaitMergeFileFinish(); - void WaitBuildIndexFinish(); - - void StartMetricTask(); - - void StartCompactionTask(); - Status MergeFiles(const std::string &table_id, - const meta::DateT &date, - const meta::TableFilesSchema &files); - Status BackgroundMergeFiles(const std::string &table_id); - void BackgroundCompaction(std::set table_ids); - - void StartBuildIndexTask(bool force = false); - void BackgroundBuildIndex(); - - Status BuildIndex(const meta::TableFileSchema &); - - Status MemSerialize(); + Status + QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, + uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results); + + void + BackgroundTimerTask(); + void + WaitMergeFileFinish(); + void + WaitBuildIndexFinish(); + + void + StartMetricTask(); + + void + StartCompactionTask(); + Status + MergeFiles(const std::string& table_id, const meta::DateT& date, const meta::TableFilesSchema& files); + Status + BackgroundMergeFiles(const std::string& table_id); + void + BackgroundCompaction(std::set table_ids); + + void + StartBuildIndexTask(bool force = false); + void + BackgroundBuildIndex(); + + Status + BuildIndex(const meta::TableFileSchema&); + + Status + MemSerialize(); private: const DBOptions options_; @@ -152,9 +160,7 @@ class DBImpl : public DB { std::list> index_thread_results_; std::mutex build_index_mutex_; -}; // DBImpl - +}; // DBImpl -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/IDGenerator.cpp b/cpp/src/db/IDGenerator.cpp index 78c88b2c035e5a436530bbdfea633eee82855acf..6b150a926a6b367569a4e47340abe0d024ec3d91 100644 --- a/cpp/src/db/IDGenerator.cpp +++ b/cpp/src/db/IDGenerator.cpp @@ -17,11 +17,10 @@ #include "db/IDGenerator.h" -#include #include +#include #include -namespace zilliz { namespace milvus { namespace engine { @@ -32,13 +31,12 @@ constexpr size_t SimpleIDGenerator::MAX_IDS_PER_MICRO; IDNumber SimpleIDGenerator::GetNextIDNumber() { auto now = std::chrono::system_clock::now(); - auto micros = std::chrono::duration_cast( - now.time_since_epoch()).count(); + auto micros = std::chrono::duration_cast(now.time_since_epoch()).count(); return micros * MAX_IDS_PER_MICRO; } void -SimpleIDGenerator::NextIDNumbers(size_t n, IDNumbers &ids) { +SimpleIDGenerator::NextIDNumbers(size_t n, IDNumbers& ids) { if (n > MAX_IDS_PER_MICRO) { NextIDNumbers(n - MAX_IDS_PER_MICRO, ids); NextIDNumbers(MAX_IDS_PER_MICRO, ids); @@ -49,8 +47,7 @@ SimpleIDGenerator::NextIDNumbers(size_t n, IDNumbers &ids) { } auto now = std::chrono::system_clock::now(); - auto micros = std::chrono::duration_cast( - now.time_since_epoch()).count(); + auto micros = std::chrono::duration_cast(now.time_since_epoch()).count(); micros *= MAX_IDS_PER_MICRO; for (int pos = 0; pos < n; ++pos) { @@ -59,11 +56,10 @@ SimpleIDGenerator::NextIDNumbers(size_t n, IDNumbers &ids) { } void -SimpleIDGenerator::GetNextIDNumbers(size_t n, IDNumbers &ids) { +SimpleIDGenerator::GetNextIDNumbers(size_t n, IDNumbers& ids) { ids.clear(); NextIDNumbers(n, ids); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/IDGenerator.h b/cpp/src/db/IDGenerator.h index 0785870c4c9721b660511225db49b9b73313c335..cdd22efc1ae5be96a233fc9fd3456f4f6c91d87c 100644 --- a/cpp/src/db/IDGenerator.h +++ b/cpp/src/db/IDGenerator.h @@ -22,22 +22,19 @@ #include #include -namespace zilliz { namespace milvus { namespace engine { class IDGenerator { public: - virtual - IDNumber GetNextIDNumber() = 0; + virtual IDNumber + GetNextIDNumber() = 0; virtual void - GetNextIDNumbers(size_t n, IDNumbers &ids) = 0; - - virtual - ~IDGenerator() = 0; -}; // IDGenerator + GetNextIDNumbers(size_t n, IDNumbers& ids) = 0; + virtual ~IDGenerator() = 0; +}; // IDGenerator class SimpleIDGenerator : public IDGenerator { public: @@ -47,16 +44,14 @@ class SimpleIDGenerator : public IDGenerator { GetNextIDNumber() override; void - GetNextIDNumbers(size_t n, IDNumbers &ids) override; + GetNextIDNumbers(size_t n, IDNumbers& ids) override; private: void - NextIDNumbers(size_t n, IDNumbers &ids); + NextIDNumbers(size_t n, IDNumbers& ids); static constexpr size_t MAX_IDS_PER_MICRO = 1000; -}; // SimpleIDGenerator - +}; // SimpleIDGenerator -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/Options.cpp b/cpp/src/db/Options.cpp index 4c82f3f736abd781089b75144ba56617075971cb..9d6a7d02368dc96da87ae1911e9590d17ff12784 100644 --- a/cpp/src/db/Options.cpp +++ b/cpp/src/db/Options.cpp @@ -19,28 +19,27 @@ #include "utils/Exception.h" #include "utils/Log.h" -#include #include +#include #include -namespace zilliz { namespace milvus { namespace engine { -ArchiveConf::ArchiveConf(const std::string &type, const std::string &criterias) { +ArchiveConf::ArchiveConf(const std::string& type, const std::string& criterias) { ParseType(type); ParseCritirias(criterias); } void -ArchiveConf::SetCriterias(const ArchiveConf::CriteriaT &criterial) { - for (auto &pair : criterial) { +ArchiveConf::SetCriterias(const ArchiveConf::CriteriaT& criterial) { + for (auto& pair : criterial) { criterias_[pair.first] = pair.second; } } void -ArchiveConf::ParseCritirias(const std::string &criterias) { +ArchiveConf::ParseCritirias(const std::string& criterias) { std::stringstream ss(criterias); std::vector tokens; @@ -50,7 +49,7 @@ ArchiveConf::ParseCritirias(const std::string &criterias) { return; } - for (auto &token : tokens) { + for (auto& token : tokens) { if (token.empty()) { continue; } @@ -68,13 +67,11 @@ ArchiveConf::ParseCritirias(const std::string &criterias) { try { auto value = std::stoi(kv[1]); criterias_[kv[0]] = value; - } - catch (std::out_of_range &) { + } catch (std::out_of_range&) { std::string msg = "Out of range: '" + kv[1] + "'"; ENGINE_LOG_ERROR << msg; throw InvalidArgumentException(msg); - } - catch (...) { + } catch (...) { std::string msg = "Invalid argument: '" + kv[1] + "'"; ENGINE_LOG_ERROR << msg; throw InvalidArgumentException(msg); @@ -83,7 +80,7 @@ ArchiveConf::ParseCritirias(const std::string &criterias) { } void -ArchiveConf::ParseType(const std::string &type) { +ArchiveConf::ParseType(const std::string& type) { if (type != "delete" && type != "swap") { std::string msg = "Invalid argument: type='" + type + "'"; throw InvalidArgumentException(msg); @@ -91,6 +88,5 @@ ArchiveConf::ParseType(const std::string &type) { type_ = type; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/Options.h b/cpp/src/db/Options.h index d8be2767df50227845ed367d83f44810e9788d7e..ebecb4de5a0ff2c51174cb123d9cd6466ae54bde 100644 --- a/cpp/src/db/Options.h +++ b/cpp/src/db/Options.h @@ -19,38 +19,42 @@ #include "Constants.h" -#include -#include #include +#include +#include #include -namespace zilliz { namespace milvus { namespace engine { class Env; -static const char *ARCHIVE_CONF_DISK = "disk"; -static const char *ARCHIVE_CONF_DAYS = "days"; +static const char* ARCHIVE_CONF_DISK = "disk"; +static const char* ARCHIVE_CONF_DAYS = "days"; struct ArchiveConf { using CriteriaT = std::map; - explicit ArchiveConf(const std::string &type, const std::string &criterias = std::string()); + explicit ArchiveConf(const std::string& type, const std::string& criterias = std::string()); - const std::string &GetType() const { + const std::string& + GetType() const { return type_; } - const CriteriaT GetCriterias() const { + const CriteriaT + GetCriterias() const { return criterias_; } - void SetCriterias(const ArchiveConf::CriteriaT &criterial); + void + SetCriterias(const ArchiveConf::CriteriaT& criterial); private: - void ParseCritirias(const std::string &type); - void ParseType(const std::string &criterias); + void + ParseCritirias(const std::string& criterias); + void + ParseType(const std::string& type); std::string type_; CriteriaT criterias_; @@ -61,14 +65,10 @@ struct DBMetaOptions { std::vector slave_paths_; std::string backend_uri_; ArchiveConf archive_conf_ = ArchiveConf("delete"); -}; // DBMetaOptions +}; // DBMetaOptions struct DBOptions { - typedef enum { - SINGLE = 0, - CLUSTER_READONLY, - CLUSTER_WRITABLE - } MODE; + typedef enum { SINGLE = 0, CLUSTER_READONLY, CLUSTER_WRITABLE } MODE; uint16_t merge_trigger_number_ = 2; DBMetaOptions meta_; @@ -76,9 +76,7 @@ struct DBOptions { size_t insert_buffer_size_ = 4 * ONE_GB; bool insert_cache_immediately_ = false; -}; // Options - +}; // Options -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/Types.h b/cpp/src/db/Types.h index 04bf680937c7c9200a48865b89337166668243bf..94528a9a8afb0023ec09ad3d12603c1a3f0a6746 100644 --- a/cpp/src/db/Types.h +++ b/cpp/src/db/Types.h @@ -19,27 +19,25 @@ #include "db/engine/ExecutionEngine.h" -#include #include #include +#include -namespace zilliz { namespace milvus { namespace engine { typedef int64_t IDNumber; -typedef IDNumber *IDNumberPtr; +typedef IDNumber* IDNumberPtr; typedef std::vector IDNumbers; typedef std::vector> QueryResult; typedef std::vector QueryResults; struct TableIndex { - int32_t engine_type_ = (int) EngineType::FAISS_IDMAP; + int32_t engine_type_ = (int)EngineType::FAISS_IDMAP; int32_t nlist_ = 16384; - int32_t metric_type_ = (int) MetricType::L2; + int32_t metric_type_ = (int)MetricType::L2; }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/Utils.cpp b/cpp/src/db/Utils.cpp index b3d72202cc618e14db0447a7ea14346d094ed1b5..0ddf03568aaa280db059acbab3d7b2f51b277cc4 100644 --- a/cpp/src/db/Utils.cpp +++ b/cpp/src/db/Utils.cpp @@ -19,33 +19,32 @@ #include "utils/CommonUtil.h" #include "utils/Log.h" -#include +#include #include +#include #include #include -#include -namespace zilliz { namespace milvus { namespace engine { namespace utils { namespace { -const char *TABLES_FOLDER = "/tables/"; +const char* TABLES_FOLDER = "/tables/"; uint64_t index_file_counter = 0; std::mutex index_file_counter_mutex; std::string -ConstructParentFolder(const std::string &db_path, const meta::TableFileSchema &table_file) { +ConstructParentFolder(const std::string& db_path, const meta::TableFileSchema& table_file) { std::string table_path = db_path + TABLES_FOLDER + table_file.table_id_; std::string partition_path = table_path + "/" + std::to_string(table_file.date_); return partition_path; } std::string -GetTableFileParentFolder(const DBMetaOptions &options, const meta::TableFileSchema &table_file) { +GetTableFileParentFolder(const DBMetaOptions& options, const meta::TableFileSchema& table_file) { uint64_t path_count = options.slave_paths_.size() + 1; std::string target_path = options.path_; uint64_t index = 0; @@ -70,19 +69,18 @@ GetTableFileParentFolder(const DBMetaOptions &options, const meta::TableFileSche return ConstructParentFolder(target_path, table_file); } -} // namespace +} // namespace int64_t GetMicroSecTimeStamp() { auto now = std::chrono::system_clock::now(); - auto micros = std::chrono::duration_cast( - now.time_since_epoch()).count(); + auto micros = std::chrono::duration_cast(now.time_since_epoch()).count(); return micros; } Status -CreateTablePath(const DBMetaOptions &options, const std::string &table_id) { +CreateTablePath(const DBMetaOptions& options, const std::string& table_id) { std::string db_path = options.path_; std::string table_path = db_path + TABLES_FOLDER + table_id; auto status = server::CommonUtil::CreateDirectory(table_path); @@ -91,7 +89,7 @@ CreateTablePath(const DBMetaOptions &options, const std::string &table_id) { return status; } - for (auto &path : options.slave_paths_) { + for (auto& path : options.slave_paths_) { table_path = path + TABLES_FOLDER + table_id; status = server::CommonUtil::CreateDirectory(table_path); if (!status.ok()) { @@ -104,17 +102,16 @@ CreateTablePath(const DBMetaOptions &options, const std::string &table_id) { } Status -DeleteTablePath(const DBMetaOptions &options, const std::string &table_id, bool force) { +DeleteTablePath(const DBMetaOptions& options, const std::string& table_id, bool force) { std::vector paths = options.slave_paths_; paths.push_back(options.path_); - for (auto &path : paths) { + for (auto& path : paths) { std::string table_path = path + TABLES_FOLDER + table_id; if (force) { boost::filesystem::remove_all(table_path); ENGINE_LOG_DEBUG << "Remove table folder: " << table_path; - } else if (boost::filesystem::exists(table_path) && - boost::filesystem::is_empty(table_path)) { + } else if (boost::filesystem::exists(table_path) && boost::filesystem::is_empty(table_path)) { boost::filesystem::remove_all(table_path); ENGINE_LOG_DEBUG << "Remove table folder: " << table_path; } @@ -124,7 +121,7 @@ DeleteTablePath(const DBMetaOptions &options, const std::string &table_id, bool } Status -CreateTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file) { +CreateTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file) { std::string parent_path = GetTableFileParentFolder(options, table_file); auto status = server::CommonUtil::CreateDirectory(parent_path); @@ -139,46 +136,44 @@ CreateTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_f } Status -GetTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file) { +GetTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file) { std::string parent_path = ConstructParentFolder(options.path_, table_file); std::string file_path = parent_path + "/" + table_file.file_id_; if (boost::filesystem::exists(file_path)) { table_file.location_ = file_path; return Status::OK(); - } else { - for (auto &path : options.slave_paths_) { - parent_path = ConstructParentFolder(path, table_file); - file_path = parent_path + "/" + table_file.file_id_; - if (boost::filesystem::exists(file_path)) { - table_file.location_ = file_path; - return Status::OK(); - } + } + + for (auto& path : options.slave_paths_) { + parent_path = ConstructParentFolder(path, table_file); + file_path = parent_path + "/" + table_file.file_id_; + if (boost::filesystem::exists(file_path)) { + table_file.location_ = file_path; + return Status::OK(); } } std::string msg = "Table file doesn't exist: " + file_path; - ENGINE_LOG_ERROR << msg << " in path: " << options.path_ - << " for table: " << table_file.table_id_; + ENGINE_LOG_ERROR << msg << " in path: " << options.path_ << " for table: " << table_file.table_id_; return Status(DB_ERROR, msg); } Status -DeleteTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file) { +DeleteTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file) { utils::GetTableFilePath(options, table_file); boost::filesystem::remove(table_file.location_); return Status::OK(); } bool -IsSameIndex(const TableIndex &index1, const TableIndex &index2) { - return index1.engine_type_ == index2.engine_type_ - && index1.nlist_ == index2.nlist_ - && index1.metric_type_ == index2.metric_type_; +IsSameIndex(const TableIndex& index1, const TableIndex& index2) { + return index1.engine_type_ == index2.engine_type_ && index1.nlist_ == index2.nlist_ && + index1.metric_type_ == index2.metric_type_; } meta::DateT -GetDate(const std::time_t &t, int day_delta) { +GetDate(const std::time_t& t, int day_delta) { struct tm ltm; localtime_r(&t, <m); if (day_delta > 0) { @@ -211,20 +206,15 @@ GetDate() { // URI format: dialect://username:password@host:port/database Status -ParseMetaUri(const std::string &uri, MetaUriInfo &info) { +ParseMetaUri(const std::string& uri, MetaUriInfo& info) { std::string dialect_regex = "(.*)"; std::string username_tegex = "(.*)"; std::string password_regex = "(.*)"; std::string host_regex = "(.*)"; std::string port_regex = "(.*)"; std::string db_name_regex = "(.*)"; - std::string uri_regex_str = - dialect_regex + "\\:\\/\\/" + - username_tegex + "\\:" + - password_regex + "\\@" + - host_regex + "\\:" + - port_regex + "\\/" + - db_name_regex; + std::string uri_regex_str = dialect_regex + "\\:\\/\\/" + username_tegex + "\\:" + password_regex + "\\@" + + host_regex + "\\:" + port_regex + "\\/" + db_name_regex; std::regex uri_regex(uri_regex_str); std::smatch pieces_match; @@ -237,7 +227,7 @@ ParseMetaUri(const std::string &uri, MetaUriInfo &info) { info.port_ = pieces_match[5].str(); info.db_name_ = pieces_match[6].str(); - //TODO: verify host, port... + // TODO(myh): verify host, port... } else { return Status(DB_INVALID_META_URI, "Invalid meta uri: " + uri); } @@ -245,7 +235,6 @@ ParseMetaUri(const std::string &uri, MetaUriInfo &info) { return Status::OK(); } -} // namespace utils -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace utils +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/Utils.h b/cpp/src/db/Utils.h index fbeca84686c8fd9ed5959856d8e6e549bf3a4658..0b157f6dfd932394845b82a4d9685bd9ce52ef6f 100644 --- a/cpp/src/db/Utils.h +++ b/cpp/src/db/Utils.h @@ -18,13 +18,12 @@ #pragma once #include "Options.h" -#include "db/meta/MetaTypes.h" #include "db/Types.h" +#include "db/meta/MetaTypes.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { namespace utils { @@ -33,22 +32,22 @@ int64_t GetMicroSecTimeStamp(); Status -CreateTablePath(const DBMetaOptions &options, const std::string &table_id); +CreateTablePath(const DBMetaOptions& options, const std::string& table_id); Status -DeleteTablePath(const DBMetaOptions &options, const std::string &table_id, bool force = true); +DeleteTablePath(const DBMetaOptions& options, const std::string& table_id, bool force = true); Status -CreateTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file); +CreateTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file); Status -GetTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file); +GetTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file); Status -DeleteTableFilePath(const DBMetaOptions &options, meta::TableFileSchema &table_file); +DeleteTableFilePath(const DBMetaOptions& options, meta::TableFileSchema& table_file); bool -IsSameIndex(const TableIndex &index1, const TableIndex &index2); +IsSameIndex(const TableIndex& index1, const TableIndex& index2); meta::DateT -GetDate(const std::time_t &t, int day_delta = 0); +GetDate(const std::time_t& t, int day_delta = 0); meta::DateT GetDate(); meta::DateT @@ -64,9 +63,8 @@ struct MetaUriInfo { }; Status -ParseMetaUri(const std::string &uri, MetaUriInfo &info); +ParseMetaUri(const std::string& uri, MetaUriInfo& info); -} // namespace utils -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace utils +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/engine/EngineFactory.cpp b/cpp/src/db/engine/EngineFactory.cpp index 043744f8037e808ecf92ab0a8c2fc716ffde237e..c3597d1020106e419c1f66995d7efcde490ce396 100644 --- a/cpp/src/db/engine/EngineFactory.cpp +++ b/cpp/src/db/engine/EngineFactory.cpp @@ -21,22 +21,18 @@ #include -namespace zilliz { namespace milvus { namespace engine { ExecutionEnginePtr -EngineFactory::Build(uint16_t dimension, - const std::string &location, - EngineType index_type, - MetricType metric_type, +EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type, int32_t nlist) { if (index_type == EngineType::INVALID) { ENGINE_LOG_ERROR << "Unsupported engine type"; return nullptr; } - ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int) index_type; + ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int)index_type; ExecutionEnginePtr execution_engine_ptr = std::make_shared(dimension, location, index_type, metric_type, nlist); @@ -44,6 +40,5 @@ EngineFactory::Build(uint16_t dimension, return execution_engine_ptr; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/engine/EngineFactory.h b/cpp/src/db/engine/EngineFactory.h index 105e48f8856cfd3b4258d83af82b72c7315e5e10..d98952ccd9c99f2ed771600a488d935d2243b558 100644 --- a/cpp/src/db/engine/EngineFactory.h +++ b/cpp/src/db/engine/EngineFactory.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "ExecutionEngine.h" @@ -23,20 +22,15 @@ #include -namespace zilliz { namespace milvus { namespace engine { class EngineFactory { public: - static ExecutionEnginePtr Build(uint16_t dimension, - const std::string &location, - EngineType index_type, - MetricType metric_type, - int32_t nlist); + static ExecutionEnginePtr + Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type, + int32_t nlist); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz - +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/engine/ExecutionEngine.h b/cpp/src/db/engine/ExecutionEngine.h index c63760686f310e81a26bf478ddbe1e52cfdcf8c5..848704bd4bbcd633e012a6fa0797c4d8861c198c 100644 --- a/cpp/src/db/engine/ExecutionEngine.h +++ b/cpp/src/db/engine/ExecutionEngine.h @@ -19,11 +19,10 @@ #include "utils/Status.h" -#include #include #include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -43,52 +42,68 @@ enum class MetricType { class ExecutionEngine { public: - virtual Status AddWithIds(int64_t n, const float *xdata, const int64_t *xids) = 0; + virtual Status + AddWithIds(int64_t n, const float* xdata, const int64_t* xids) = 0; + + virtual size_t + Count() const = 0; - virtual size_t Count() const = 0; + virtual size_t + Size() const = 0; - virtual size_t Size() const = 0; + virtual size_t + Dimension() const = 0; - virtual size_t Dimension() const = 0; + virtual size_t + PhysicalSize() const = 0; - virtual size_t PhysicalSize() const = 0; + virtual Status + Serialize() = 0; - virtual Status Serialize() = 0; + virtual Status + Load(bool to_cache = true) = 0; - virtual Status Load(bool to_cache = true) = 0; + virtual Status + CopyToGpu(uint64_t device_id) = 0; - virtual Status CopyToGpu(uint64_t device_id) = 0; + virtual Status + CopyToIndexFileToGpu(uint64_t device_id) = 0; - virtual Status CopyToCpu() = 0; + virtual Status + CopyToCpu() = 0; - virtual std::shared_ptr Clone() = 0; + virtual std::shared_ptr + Clone() = 0; - virtual Status Merge(const std::string &location) = 0; + virtual Status + Merge(const std::string& location) = 0; - virtual Status Search(int64_t n, - const float *data, - int64_t k, - int64_t nprobe, - float *distances, - int64_t *labels) const = 0; + virtual Status + Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels) const = 0; - virtual std::shared_ptr BuildIndex(const std::string &location, EngineType engine_type) = 0; + virtual std::shared_ptr + BuildIndex(const std::string& location, EngineType engine_type) = 0; - virtual Status Cache() = 0; + virtual Status + Cache() = 0; - virtual Status GpuCache(uint64_t gpu_id) = 0; + virtual Status + GpuCache(uint64_t gpu_id) = 0; - virtual Status Init() = 0; + virtual Status + Init() = 0; - virtual EngineType IndexEngineType() const = 0; + virtual EngineType + IndexEngineType() const = 0; - virtual MetricType IndexMetricType() const = 0; + virtual MetricType + IndexMetricType() const = 0; - virtual std::string GetLocation() const = 0; + virtual std::string + GetLocation() const = 0; }; using ExecutionEnginePtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/engine/ExecutionEngineImpl.cpp b/cpp/src/db/engine/ExecutionEngineImpl.cpp index a9937092fde81e7fde28c6f3bd7e85da941f25e0..51a31d1143de91f5e17816d9c3622e924f12018b 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.cpp +++ b/cpp/src/db/engine/ExecutionEngineImpl.cpp @@ -16,39 +16,30 @@ // under the License. #include "db/engine/ExecutionEngineImpl.h" -#include "cache/GpuCacheMgr.h" #include "cache/CpuCacheMgr.h" +#include "cache/GpuCacheMgr.h" #include "metrics/Metrics.h" -#include "utils/Log.h" #include "utils/CommonUtil.h" #include "utils/Exception.h" +#include "utils/Log.h" -#include "src/wrapper/VecIndex.h" -#include "src/wrapper/VecImpl.h" -#include "knowhere/common/Exception.h" #include "knowhere/common/Config.h" -#include "wrapper/ConfAdapterMgr.h" -#include "wrapper/ConfAdapter.h" +#include "knowhere/common/Exception.h" #include "server/Config.h" +#include "src/wrapper/VecImpl.h" +#include "src/wrapper/VecIndex.h" +#include "wrapper/ConfAdapter.h" +#include "wrapper/ConfAdapterMgr.h" #include #include -namespace zilliz { namespace milvus { namespace engine { -ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, - const std::string &location, - EngineType index_type, - MetricType metric_type, - int32_t nlist) - : location_(location), - dim_(dimension), - index_type_(index_type), - metric_type_(metric_type), - nlist_(nlist) { - +ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, + MetricType metric_type, int32_t nlist) + : location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) { index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); if (!index_) { throw Exception(DB_ERROR, "Could not create VecIndex"); @@ -57,8 +48,7 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, TempMetaConf temp_conf; temp_conf.gpu_id = gpu_num_; temp_conf.dim = dimension; - temp_conf.metric_type = (metric_type_ == MetricType::IP) ? - knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2; + temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2; auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType()); auto conf = adapter->Match(temp_conf); @@ -68,16 +58,9 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, } } -ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, - const std::string &location, - EngineType index_type, - MetricType metric_type, - int32_t nlist) - : index_(std::move(index)), - location_(location), - index_type_(index_type), - metric_type_(metric_type), - nlist_(nlist) { +ExecutionEngineImpl::ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, + MetricType metric_type, int32_t nlist) + : index_(std::move(index)), location_(location), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) { } VecIndexPtr @@ -109,7 +92,7 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { } Status -ExecutionEngineImpl::AddWithIds(int64_t n, const float *xdata, const int64_t *xids) { +ExecutionEngineImpl::AddWithIds(int64_t n, const float* xdata, const int64_t* xids) { auto status = index_->Add(n, xdata, xids); return status; } @@ -125,7 +108,7 @@ ExecutionEngineImpl::Count() const { size_t ExecutionEngineImpl::Size() const { - return (size_t) (Count() * Dimension()) * sizeof(float); + return (size_t)(Count() * Dimension()) * sizeof(float); } size_t @@ -164,7 +147,7 @@ ExecutionEngineImpl::Load(bool to_cache) { } else { ENGINE_LOG_DEBUG << "Disk io from: " << location_; } - } catch (std::exception &e) { + } catch (std::exception& e) { ENGINE_LOG_ERROR << e.what(); return Status(DB_ERROR, e.what()); } @@ -191,7 +174,7 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) { try { index_ = index_->CopyToGpu(device_id); ENGINE_LOG_DEBUG << "CPU to GPU" << device_id; - } catch (std::exception &e) { + } catch (std::exception& e) { ENGINE_LOG_ERROR << e.what(); return Status(DB_ERROR, e.what()); } @@ -204,6 +187,17 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) { return Status::OK(); } +Status +ExecutionEngineImpl::CopyToIndexFileToGpu(uint64_t device_id) { + auto index = cache::GpuCacheMgr::GetInstance(device_id)->GetIndex(location_); + bool already_in_cache = (index != nullptr); + if (!already_in_cache) { + cache::DataObjPtr obj = std::make_shared(nullptr, PhysicalSize()); + milvus::cache::GpuCacheMgr::GetInstance(device_id)->InsertItem(location_, obj); + } + return Status::OK(); +} + Status ExecutionEngineImpl::CopyToCpu() { auto index = cache::CpuCacheMgr::GetInstance()->GetIndex(location_); @@ -219,7 +213,7 @@ ExecutionEngineImpl::CopyToCpu() { try { index_ = index_->CopyToCpu(); ENGINE_LOG_DEBUG << "GPU to CPU"; - } catch (std::exception &e) { + } catch (std::exception& e) { ENGINE_LOG_ERROR << e.what(); return Status(DB_ERROR, e.what()); } @@ -245,7 +239,7 @@ ExecutionEngineImpl::Clone() { } Status -ExecutionEngineImpl::Merge(const std::string &location) { +ExecutionEngineImpl::Merge(const std::string& location) { if (location == location_) { return Status(DB_ERROR, "Cannot Merge Self"); } @@ -257,7 +251,7 @@ ExecutionEngineImpl::Merge(const std::string &location) { double physical_size = server::CommonUtil::GetFileSize(location); server::CollectExecutionEngineMetrics metrics(physical_size); to_merge = read_index(location); - } catch (std::exception &e) { + } catch (std::exception& e) { ENGINE_LOG_ERROR << e.what(); return Status(DB_ERROR, e.what()); } @@ -280,7 +274,7 @@ ExecutionEngineImpl::Merge(const std::string &location) { } ExecutionEnginePtr -ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_type) { +ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_type) { ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; auto from_index = std::dynamic_pointer_cast(index_); @@ -298,29 +292,23 @@ ExecutionEngineImpl::BuildIndex(const std::string &location, EngineType engine_t temp_conf.gpu_id = gpu_num_; temp_conf.dim = Dimension(); temp_conf.nlist = nlist_; - temp_conf.metric_type = (metric_type_ == MetricType::IP) ? - knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2; + temp_conf.metric_type = (metric_type_ == MetricType::IP) ? knowhere::METRICTYPE::IP : knowhere::METRICTYPE::L2; temp_conf.size = Count(); auto adapter = AdapterMgr::GetInstance().GetAdapter(to_index->GetType()); auto conf = adapter->Match(temp_conf); - auto status = to_index->BuildAll(Count(), - from_index->GetRawVectors(), - from_index->GetRawIds(), - conf); - if (!status.ok()) { throw Exception(DB_ERROR, status.message()); } + auto status = to_index->BuildAll(Count(), from_index->GetRawVectors(), from_index->GetRawIds(), conf); + if (!status.ok()) { + throw Exception(DB_ERROR, status.message()); + } return std::make_shared(to_index, location, engine_type, metric_type_, nlist_); } Status -ExecutionEngineImpl::Search(int64_t n, - const float *data, - int64_t k, - int64_t nprobe, - float *distances, - int64_t *labels) const { +ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, + int64_t* labels) const { if (index_ == nullptr) { ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search"; return Status(DB_ERROR, "index is null"); @@ -346,7 +334,7 @@ ExecutionEngineImpl::Search(int64_t n, Status ExecutionEngineImpl::Cache() { cache::DataObjPtr obj = std::make_shared(index_, PhysicalSize()); - zilliz::milvus::cache::CpuCacheMgr::GetInstance()->InsertItem(location_, obj); + milvus::cache::CpuCacheMgr::GetInstance()->InsertItem(location_, obj); return Status::OK(); } @@ -354,7 +342,7 @@ ExecutionEngineImpl::Cache() { Status ExecutionEngineImpl::GpuCache(uint64_t gpu_id) { cache::DataObjPtr obj = std::make_shared(index_, PhysicalSize()); - zilliz::milvus::cache::GpuCacheMgr::GetInstance(gpu_id)->InsertItem(location_, obj); + milvus::cache::GpuCacheMgr::GetInstance(gpu_id)->InsertItem(location_, obj); return Status::OK(); } @@ -362,13 +350,14 @@ ExecutionEngineImpl::GpuCache(uint64_t gpu_id) { // TODO(linxj): remove. Status ExecutionEngineImpl::Init() { - server::Config &config = server::Config::GetInstance(); + server::Config& config = server::Config::GetInstance(); Status s = config.GetDBConfigBuildIndexGPU(gpu_num_); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } return Status::OK(); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/engine/ExecutionEngineImpl.h b/cpp/src/db/engine/ExecutionEngineImpl.h index 6ea09ddb29948a4cd1c791f2d876415aacfb16e4..56a584999439b87a962a1f762ac497f62621454b 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.h +++ b/cpp/src/db/engine/ExecutionEngineImpl.h @@ -23,77 +23,89 @@ #include #include -namespace zilliz { namespace milvus { namespace engine { class ExecutionEngineImpl : public ExecutionEngine { public: - ExecutionEngineImpl(uint16_t dimension, - const std::string &location, - EngineType index_type, - MetricType metric_type, + ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type, int32_t nlist); - ExecutionEngineImpl(VecIndexPtr index, - const std::string &location, - EngineType index_type, - MetricType metric_type, + ExecutionEngineImpl(VecIndexPtr index, const std::string& location, EngineType index_type, MetricType metric_type, int32_t nlist); - Status AddWithIds(int64_t n, const float *xdata, const int64_t *xids) override; + Status + AddWithIds(int64_t n, const float* xdata, const int64_t* xids) override; - size_t Count() const override; + size_t + Count() const override; - size_t Size() const override; + size_t + Size() const override; - size_t Dimension() const override; + size_t + Dimension() const override; - size_t PhysicalSize() const override; + size_t + PhysicalSize() const override; - Status Serialize() override; + Status + Serialize() override; - Status Load(bool to_cache) override; + Status + Load(bool to_cache) override; - Status CopyToGpu(uint64_t device_id) override; + Status + CopyToGpu(uint64_t device_id) override; - Status CopyToCpu() override; + Status + CopyToIndexFileToGpu(uint64_t device_id) override; - ExecutionEnginePtr Clone() override; + Status + CopyToCpu() override; - Status Merge(const std::string &location) override; + ExecutionEnginePtr + Clone() override; - Status Search(int64_t n, - const float *data, - int64_t k, - int64_t nprobe, - float *distances, - int64_t *labels) const override; + Status + Merge(const std::string& location) override; - ExecutionEnginePtr BuildIndex(const std::string &location, EngineType engine_type) override; + Status + Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels) const override; - Status Cache() override; + ExecutionEnginePtr + BuildIndex(const std::string& location, EngineType engine_type) override; - Status GpuCache(uint64_t gpu_id) override; + Status + Cache() override; - Status Init() override; + Status + GpuCache(uint64_t gpu_id) override; - EngineType IndexEngineType() const override { + Status + Init() override; + + EngineType + IndexEngineType() const override { return index_type_; } - MetricType IndexMetricType() const override { + MetricType + IndexMetricType() const override { return metric_type_; } - std::string GetLocation() const override { + std::string + GetLocation() const override { return location_; } private: - VecIndexPtr CreatetVecIndex(EngineType type); + VecIndexPtr + CreatetVecIndex(EngineType type); - VecIndexPtr Load(const std::string &location); + VecIndexPtr + Load(const std::string& location); protected: VecIndexPtr index_ = nullptr; @@ -107,6 +119,5 @@ class ExecutionEngineImpl : public ExecutionEngine { int32_t gpu_num_ = 0; }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemManager.h b/cpp/src/db/insert/MemManager.h index 751633bf67385571d850992cc0a95845952ee725..cc766041657d5b7553107f1c6cdb0f5af837b57c 100644 --- a/cpp/src/db/insert/MemManager.h +++ b/cpp/src/db/insert/MemManager.h @@ -15,38 +15,40 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "utils/Status.h" #include "db/Types.h" +#include "utils/Status.h" -#include #include +#include #include -namespace zilliz { namespace milvus { namespace engine { class MemManager { public: - virtual Status InsertVectors(const std::string &table_id, - size_t n, const float *vectors, IDNumbers &vector_ids) = 0; + virtual Status + InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) = 0; - virtual Status Serialize(std::set &table_ids) = 0; + virtual Status + Serialize(std::set& table_ids) = 0; - virtual Status EraseMemVector(const std::string &table_id) = 0; + virtual Status + EraseMemVector(const std::string& table_id) = 0; - virtual size_t GetCurrentMutableMem() = 0; + virtual size_t + GetCurrentMutableMem() = 0; - virtual size_t GetCurrentImmutableMem() = 0; + virtual size_t + GetCurrentImmutableMem() = 0; - virtual size_t GetCurrentMem() = 0; -}; // MemManagerAbstract + virtual size_t + GetCurrentMem() = 0; +}; // MemManagerAbstract using MemManagerPtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemManagerImpl.cpp b/cpp/src/db/insert/MemManagerImpl.cpp index e555acd28dcebf6bc345866d1f833a0c9ce99e44..69c3397eb9e553c00cf4c23be8a891877c166009 100644 --- a/cpp/src/db/insert/MemManagerImpl.cpp +++ b/cpp/src/db/insert/MemManagerImpl.cpp @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. - #include "db/insert/MemManagerImpl.h" #include "VectorSource.h" -#include "utils/Log.h" #include "db/Constants.h" +#include "utils/Log.h" #include -namespace zilliz { namespace milvus { namespace engine { MemTablePtr -MemManagerImpl::GetMemByTable(const std::string &table_id) { +MemManagerImpl::GetMemByTable(const std::string& table_id) { auto memIt = mem_id_map_.find(table_id); if (memIt != mem_id_map_.end()) { return memIt->second; @@ -39,10 +37,7 @@ MemManagerImpl::GetMemByTable(const std::string &table_id) { } Status -MemManagerImpl::InsertVectors(const std::string &table_id_, - size_t n_, - const float *vectors_, - IDNumbers &vector_ids_) { +MemManagerImpl::InsertVectors(const std::string& table_id_, size_t n_, const float* vectors_, IDNumbers& vector_ids_) { while (GetCurrentMem() > options_.insert_buffer_size_) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } @@ -53,10 +48,8 @@ MemManagerImpl::InsertVectors(const std::string &table_id_, } Status -MemManagerImpl::InsertVectorsNoLock(const std::string &table_id, - size_t n, - const float *vectors, - IDNumbers &vector_ids) { +MemManagerImpl::InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors, + IDNumbers& vector_ids) { MemTablePtr mem = GetMemByTable(table_id); VectorSourcePtr source = std::make_shared(n, vectors); @@ -73,9 +66,9 @@ Status MemManagerImpl::ToImmutable() { std::unique_lock lock(mutex_); MemIdMap temp_map; - for (auto &kv : mem_id_map_) { + for (auto& kv : mem_id_map_) { if (kv.second->Empty()) { - //empty table, no need to serialize + // empty table, no need to serialize temp_map.insert(kv); } else { immu_mem_list_.push_back(kv.second); @@ -87,11 +80,11 @@ MemManagerImpl::ToImmutable() { } Status -MemManagerImpl::Serialize(std::set &table_ids) { +MemManagerImpl::Serialize(std::set& table_ids) { ToImmutable(); std::unique_lock lock(serialization_mtx_); table_ids.clear(); - for (auto &mem : immu_mem_list_) { + for (auto& mem : immu_mem_list_) { mem->Serialize(); table_ids.insert(mem->GetTableId()); } @@ -100,16 +93,16 @@ MemManagerImpl::Serialize(std::set &table_ids) { } Status -MemManagerImpl::EraseMemVector(const std::string &table_id) { - {//erase MemVector from rapid-insert cache +MemManagerImpl::EraseMemVector(const std::string& table_id) { + { // erase MemVector from rapid-insert cache std::unique_lock lock(mutex_); mem_id_map_.erase(table_id); } - {//erase MemVector from serialize cache + { // erase MemVector from serialize cache std::unique_lock lock(serialization_mtx_); MemList temp_list; - for (auto &mem : immu_mem_list_) { + for (auto& mem : immu_mem_list_) { if (mem->GetTableId() != table_id) { temp_list.push_back(mem); } @@ -123,7 +116,7 @@ MemManagerImpl::EraseMemVector(const std::string &table_id) { size_t MemManagerImpl::GetCurrentMutableMem() { size_t total_mem = 0; - for (auto &kv : mem_id_map_) { + for (auto& kv : mem_id_map_) { auto memTable = kv.second; total_mem += memTable->GetCurrentMem(); } @@ -133,7 +126,7 @@ MemManagerImpl::GetCurrentMutableMem() { size_t MemManagerImpl::GetCurrentImmutableMem() { size_t total_mem = 0; - for (auto &mem_table : immu_mem_list_) { + for (auto& mem_table : immu_mem_list_) { total_mem += mem_table->GetCurrentMem(); } return total_mem; @@ -144,6 +137,5 @@ MemManagerImpl::GetCurrentMem() { return GetCurrentMutableMem() + GetCurrentImmutableMem(); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemManagerImpl.h b/cpp/src/db/insert/MemManagerImpl.h index 1783adec1809c4eb8397f2d00dad35d585ea1957..862b068d0fa21de42f84f650af4a146e3f65fb10 100644 --- a/cpp/src/db/insert/MemManagerImpl.h +++ b/cpp/src/db/insert/MemManagerImpl.h @@ -15,23 +15,21 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "MemTable.h" #include "MemManager.h" +#include "MemTable.h" #include "db/meta/Meta.h" #include "utils/Status.h" -#include -#include -#include -#include #include +#include #include #include +#include +#include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -39,29 +37,35 @@ class MemManagerImpl : public MemManager { public: using Ptr = std::shared_ptr; - MemManagerImpl(const meta::MetaPtr &meta, const DBOptions &options) - : meta_(meta), options_(options) { + MemManagerImpl(const meta::MetaPtr& meta, const DBOptions& options) : meta_(meta), options_(options) { } - Status InsertVectors(const std::string &table_id, - size_t n, const float *vectors, IDNumbers &vector_ids) override; + Status + InsertVectors(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids) override; - Status Serialize(std::set &table_ids) override; + Status + Serialize(std::set& table_ids) override; - Status EraseMemVector(const std::string &table_id) override; + Status + EraseMemVector(const std::string& table_id) override; - size_t GetCurrentMutableMem() override; + size_t + GetCurrentMutableMem() override; - size_t GetCurrentImmutableMem() override; + size_t + GetCurrentImmutableMem() override; - size_t GetCurrentMem() override; + size_t + GetCurrentMem() override; private: - MemTablePtr GetMemByTable(const std::string &table_id); + MemTablePtr + GetMemByTable(const std::string& table_id); - Status InsertVectorsNoLock(const std::string &table_id, - size_t n, const float *vectors, IDNumbers &vector_ids); - Status ToImmutable(); + Status + InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids); + Status + ToImmutable(); using MemIdMap = std::map; using MemList = std::vector; @@ -71,8 +75,7 @@ class MemManagerImpl : public MemManager { DBOptions options_; std::mutex mutex_; std::mutex serialization_mtx_; -}; // NewMemManager +}; // NewMemManager -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemMenagerFactory.cpp b/cpp/src/db/insert/MemMenagerFactory.cpp index 76e521e168931de9aed4f9381d7bf299ecf37e11..033992501b3362908af6d9d20ab026613364be37 100644 --- a/cpp/src/db/insert/MemMenagerFactory.cpp +++ b/cpp/src/db/insert/MemMenagerFactory.cpp @@ -17,26 +17,24 @@ #include "db/insert/MemMenagerFactory.h" #include "MemManagerImpl.h" -#include "utils/Log.h" #include "utils/Exception.h" +#include "utils/Log.h" #include #include -#include #include -#include -#include #include +#include +#include +#include -namespace zilliz { namespace milvus { namespace engine { MemManagerPtr -MemManagerFactory::Build(const std::shared_ptr &meta, const DBOptions &options) { +MemManagerFactory::Build(const std::shared_ptr& meta, const DBOptions& options) { return std::make_shared(meta, options); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemMenagerFactory.h b/cpp/src/db/insert/MemMenagerFactory.h index fa0cb5930a8e756ca8cf963a71ccdc2d90e80fa5..a489a3952b772a0e985e97b527e8388216a3652f 100644 --- a/cpp/src/db/insert/MemMenagerFactory.h +++ b/cpp/src/db/insert/MemMenagerFactory.h @@ -22,15 +22,14 @@ #include -namespace zilliz { namespace milvus { namespace engine { class MemManagerFactory { public: - static MemManagerPtr Build(const std::shared_ptr &meta, const DBOptions &options); + static MemManagerPtr + Build(const std::shared_ptr& meta, const DBOptions& options); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemTable.cpp b/cpp/src/db/insert/MemTable.cpp index 8db6a9d1414c6545d43b2881f4baaec0fd0fb7df..871eab8d7d89e4d398e34de44f5472437e64ef58 100644 --- a/cpp/src/db/insert/MemTable.cpp +++ b/cpp/src/db/insert/MemTable.cpp @@ -15,27 +15,21 @@ // specific language governing permissions and limitations // under the License. - #include "db/insert/MemTable.h" #include "utils/Log.h" #include #include -namespace zilliz { namespace milvus { namespace engine { -MemTable::MemTable(const std::string &table_id, - const meta::MetaPtr &meta, - const DBOptions &options) : - table_id_(table_id), - meta_(meta), - options_(options) { +MemTable::MemTable(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options) + : table_id_(table_id), meta_(meta), options_(options) { } Status -MemTable::Add(VectorSourcePtr &source, IDNumbers &vector_ids) { +MemTable::Add(VectorSourcePtr& source, IDNumbers& vector_ids) { while (!source->AllAdded()) { MemTableFilePtr current_mem_table_file; if (!mem_table_file_list_.empty()) { @@ -63,7 +57,7 @@ MemTable::Add(VectorSourcePtr &source, IDNumbers &vector_ids) { } void -MemTable::GetCurrentMemTableFile(MemTableFilePtr &mem_table_file) { +MemTable::GetCurrentMemTableFile(MemTableFilePtr& mem_table_file) { mem_table_file = mem_table_file_list_.back(); } @@ -92,7 +86,7 @@ MemTable::Empty() { return mem_table_file_list_.empty(); } -const std::string & +const std::string& MemTable::GetTableId() const { return table_id_; } @@ -101,12 +95,11 @@ size_t MemTable::GetCurrentMem() { std::lock_guard lock(mutex_); size_t total_mem = 0; - for (auto &mem_table_file : mem_table_file_list_) { + for (auto& mem_table_file : mem_table_file_list_) { total_mem += mem_table_file->GetCurrentMem(); } return total_mem; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemTable.h b/cpp/src/db/insert/MemTable.h index da7d914b41cd96cb287e999a22c81ca124651308..cb22b6ed343cbbfd97f4f0dfc0516da251f919f1 100644 --- a/cpp/src/db/insert/MemTable.h +++ b/cpp/src/db/insert/MemTable.h @@ -15,19 +15,17 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "MemTableFile.h" #include "VectorSource.h" #include "utils/Status.h" -#include -#include #include +#include #include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -35,21 +33,28 @@ class MemTable { public: using MemTableFileList = std::vector; - MemTable(const std::string &table_id, const meta::MetaPtr &meta, const DBOptions &options); + MemTable(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options); - Status Add(VectorSourcePtr &source, IDNumbers &vector_ids); + Status + Add(VectorSourcePtr& source, IDNumbers& vector_ids); - void GetCurrentMemTableFile(MemTableFilePtr &mem_table_file); + void + GetCurrentMemTableFile(MemTableFilePtr& mem_table_file); - size_t GetTableFileCount(); + size_t + GetTableFileCount(); - Status Serialize(); + Status + Serialize(); - bool Empty(); + bool + Empty(); - const std::string &GetTableId() const; + const std::string& + GetTableId() const; - size_t GetCurrentMem(); + size_t + GetCurrentMem(); private: const std::string table_id_; @@ -61,10 +66,9 @@ class MemTable { DBOptions options_; std::mutex mutex_; -}; //MemTable +}; // MemTable using MemTablePtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemTableFile.cpp b/cpp/src/db/insert/MemTableFile.cpp index fc6c2b319a74d42e8fb1ffb35de775af48f7ce14..2c877c78ed8018816b4b403faac1395073c1942b 100644 --- a/cpp/src/db/insert/MemTableFile.cpp +++ b/cpp/src/db/insert/MemTableFile.cpp @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #include "db/insert/MemTableFile.h" #include "db/Constants.h" #include "db/engine/EngineFactory.h" @@ -25,24 +24,17 @@ #include #include -namespace zilliz { namespace milvus { namespace engine { -MemTableFile::MemTableFile(const std::string &table_id, - const meta::MetaPtr &meta, - const DBOptions &options) : - table_id_(table_id), - meta_(meta), - options_(options) { +MemTableFile::MemTableFile(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options) + : table_id_(table_id), meta_(meta), options_(options) { current_mem_ = 0; auto status = CreateTableFile(); if (status.ok()) { - execution_engine_ = EngineFactory::Build(table_file_schema_.dimension_, - table_file_schema_.location_, - (EngineType) table_file_schema_.engine_type_, - (MetricType) table_file_schema_.metric_type_, - table_file_schema_.nlist_); + execution_engine_ = EngineFactory::Build( + table_file_schema_.dimension_, table_file_schema_.location_, (EngineType)table_file_schema_.engine_type_, + (MetricType)table_file_schema_.metric_type_, table_file_schema_.nlist_); } } @@ -61,10 +53,11 @@ MemTableFile::CreateTableFile() { } Status -MemTableFile::Add(const VectorSourcePtr &source, IDNumbers &vector_ids) { +MemTableFile::Add(const VectorSourcePtr& source, IDNumbers& vector_ids) { if (table_file_schema_.dimension_ <= 0) { - std::string err_msg = "MemTableFile::Add: table_file_schema dimension = " + - std::to_string(table_file_schema_.dimension_) + ", table_id = " + table_file_schema_.table_id_; + std::string err_msg = + "MemTableFile::Add: table_file_schema dimension = " + std::to_string(table_file_schema_.dimension_) + + ", table_id = " + table_file_schema_.table_id_; ENGINE_LOG_ERROR << err_msg; return Status(DB_ERROR, "Not able to create table file"); } @@ -109,11 +102,11 @@ MemTableFile::Serialize() { table_file_schema_.file_size_ = execution_engine_->PhysicalSize(); table_file_schema_.row_count_ = execution_engine_->Count(); - //if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size - //else set file type to RAW, no need to build index - if (table_file_schema_.engine_type_ != (int) EngineType::FAISS_IDMAP) { - table_file_schema_.file_type_ = (size >= table_file_schema_.index_file_size_) ? - meta::TableFileSchema::TO_INDEX : meta::TableFileSchema::RAW; + // if index type isn't IDMAP, set file type to TO_INDEX if file size execeed index_file_size + // else set file type to RAW, no need to build index + if (table_file_schema_.engine_type_ != (int)EngineType::FAISS_IDMAP) { + table_file_schema_.file_type_ = (size >= table_file_schema_.index_file_size_) ? meta::TableFileSchema::TO_INDEX + : meta::TableFileSchema::RAW; } else { table_file_schema_.file_type_ = meta::TableFileSchema::RAW; } @@ -130,6 +123,5 @@ MemTableFile::Serialize() { return status; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/MemTableFile.h b/cpp/src/db/insert/MemTableFile.h index 1d7019a42d6fcf05c54957bc68a008f6a30a89cb..e11274b7de7ff88a00d6461f5fd6d78528ddbea6 100644 --- a/cpp/src/db/insert/MemTableFile.h +++ b/cpp/src/db/insert/MemTableFile.h @@ -15,37 +15,41 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "VectorSource.h" -#include "db/meta/Meta.h" #include "db/engine/ExecutionEngine.h" +#include "db/meta/Meta.h" #include "utils/Status.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { class MemTableFile { public: - MemTableFile(const std::string &table_id, const meta::MetaPtr &meta, const DBOptions &options); + MemTableFile(const std::string& table_id, const meta::MetaPtr& meta, const DBOptions& options); - Status Add(const VectorSourcePtr &source, IDNumbers &vector_ids); + Status + Add(const VectorSourcePtr& source, IDNumbers& vector_ids); - size_t GetCurrentMem(); + size_t + GetCurrentMem(); - size_t GetMemLeft(); + size_t + GetMemLeft(); - bool IsFull(); + bool + IsFull(); - Status Serialize(); + Status + Serialize(); private: - Status CreateTableFile(); + Status + CreateTableFile(); private: const std::string table_id_; @@ -55,10 +59,9 @@ class MemTableFile { size_t current_mem_; ExecutionEnginePtr execution_engine_; -}; //MemTableFile +}; // MemTableFile using MemTableFilePtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/VectorSource.cpp b/cpp/src/db/insert/VectorSource.cpp index 330ef08b98c19dd0c7aceaad4a68588f495b7472..dbedf58029cd6acb65987c0cf438eea5cfbb56e9 100644 --- a/cpp/src/db/insert/VectorSource.cpp +++ b/cpp/src/db/insert/VectorSource.cpp @@ -15,35 +15,27 @@ // specific language governing permissions and limitations // under the License. - #include "db/insert/VectorSource.h" -#include "db/engine/ExecutionEngine.h" #include "db/engine/EngineFactory.h" -#include "utils/Log.h" +#include "db/engine/ExecutionEngine.h" #include "metrics/Metrics.h" +#include "utils/Log.h" -namespace zilliz { namespace milvus { namespace engine { -VectorSource::VectorSource(const size_t &n, - const float *vectors) : - n_(n), - vectors_(vectors), - id_generator_(std::make_shared()) { +VectorSource::VectorSource(const size_t& n, const float* vectors) + : n_(n), vectors_(vectors), id_generator_(std::make_shared()) { current_num_vectors_added = 0; } Status -VectorSource::Add(const ExecutionEnginePtr &execution_engine, - const meta::TableFileSchema &table_file_schema, - const size_t &num_vectors_to_add, - size_t &num_vectors_added, - IDNumbers &vector_ids) { +VectorSource::Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema, + const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids) { server::CollectAddMetrics metrics(n_, table_file_schema.dimension_); - num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ? - num_vectors_to_add : n_ - current_num_vectors_added; + num_vectors_added = + current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added; IDNumbers vector_ids_to_add; if (vector_ids.empty()) { id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add); @@ -58,8 +50,7 @@ VectorSource::Add(const ExecutionEnginePtr &execution_engine, vector_ids_to_add.data()); if (status.ok()) { current_num_vectors_added += num_vectors_added; - vector_ids_.insert(vector_ids_.end(), - std::make_move_iterator(vector_ids_to_add.begin()), + vector_ids_.insert(vector_ids_.end(), std::make_move_iterator(vector_ids_to_add.begin()), std::make_move_iterator(vector_ids_to_add.end())); } else { ENGINE_LOG_ERROR << "VectorSource::Add failed: " + status.ToString(); @@ -83,6 +74,5 @@ VectorSource::GetVectorIds() { return vector_ids_; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/insert/VectorSource.h b/cpp/src/db/insert/VectorSource.h index fd31a14fa6852dbe02e7583423d597f631a591e7..1d936268f4780f037cf13ec8b44db14abd612d4a 100644 --- a/cpp/src/db/insert/VectorSource.h +++ b/cpp/src/db/insert/VectorSource.h @@ -15,48 +15,46 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "db/meta/Meta.h" #include "db/IDGenerator.h" #include "db/engine/ExecutionEngine.h" +#include "db/meta/Meta.h" #include "utils/Status.h" #include -namespace zilliz { namespace milvus { namespace engine { class VectorSource { public: - VectorSource(const size_t &n, const float *vectors); + VectorSource(const size_t& n, const float* vectors); - Status Add(const ExecutionEnginePtr &execution_engine, - const meta::TableFileSchema &table_file_schema, - const size_t &num_vectors_to_add, - size_t &num_vectors_added, - IDNumbers &vector_ids); + Status + Add(const ExecutionEnginePtr& execution_engine, const meta::TableFileSchema& table_file_schema, + const size_t& num_vectors_to_add, size_t& num_vectors_added, IDNumbers& vector_ids); - size_t GetNumVectorsAdded(); + size_t + GetNumVectorsAdded(); - bool AllAdded(); + bool + AllAdded(); - IDNumbers GetVectorIds(); + IDNumbers + GetVectorIds(); private: const size_t n_; - const float *vectors_; + const float* vectors_; IDNumbers vector_ids_; size_t current_num_vectors_added; std::shared_ptr id_generator_; -}; //VectorSource +}; // VectorSource using VectorSourcePtr = std::shared_ptr; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/Meta.h b/cpp/src/db/meta/Meta.h index 42fe6a6dd546fbf7c704489592f560c3f3db7661..8167834568fff65d863c587363259177fe93e323 100644 --- a/cpp/src/db/meta/Meta.h +++ b/cpp/src/db/meta/Meta.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "MetaTypes.h" @@ -25,84 +24,102 @@ #include #include -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { namespace meta { -static const char *META_TABLES = "Tables"; -static const char *META_TABLEFILES = "TableFiles"; +static const char* META_TABLES = "Tables"; +static const char* META_TABLEFILES = "TableFiles"; class Meta { public: virtual ~Meta() = default; - virtual Status CreateTable(TableSchema &table_schema) = 0; + virtual Status + CreateTable(TableSchema& table_schema) = 0; - virtual Status DescribeTable(TableSchema &table_schema) = 0; + virtual Status + DescribeTable(TableSchema& table_schema) = 0; - virtual Status HasTable(const std::string &table_id, bool &has_or_not) = 0; + virtual Status + HasTable(const std::string& table_id, bool& has_or_not) = 0; - virtual Status AllTables(std::vector &table_schema_array) = 0; + virtual Status + AllTables(std::vector& table_schema_array) = 0; - virtual Status UpdateTableIndex(const std::string &table_id, const TableIndex &index) = 0; + virtual Status + UpdateTableIndex(const std::string& table_id, const TableIndex& index) = 0; - virtual Status UpdateTableFlag(const std::string &table_id, int64_t flag) = 0; + virtual Status + UpdateTableFlag(const std::string& table_id, int64_t flag) = 0; - virtual Status DeleteTable(const std::string &table_id) = 0; + virtual Status + DeleteTable(const std::string& table_id) = 0; - virtual Status DeleteTableFiles(const std::string &table_id) = 0; + virtual Status + DeleteTableFiles(const std::string& table_id) = 0; - virtual Status CreateTableFile(TableFileSchema &file_schema) = 0; + virtual Status + CreateTableFile(TableFileSchema& file_schema) = 0; - virtual Status DropPartitionsByDates(const std::string &table_id, const DatesT &dates) = 0; + virtual Status + DropPartitionsByDates(const std::string& table_id, const DatesT& dates) = 0; - virtual Status GetTableFiles(const std::string &table_id, - const std::vector &ids, - TableFilesSchema &table_files) = 0; + virtual Status + GetTableFiles(const std::string& table_id, const std::vector& ids, TableFilesSchema& table_files) = 0; - virtual Status UpdateTableFilesToIndex(const std::string &table_id) = 0; + virtual Status + UpdateTableFilesToIndex(const std::string& table_id) = 0; - virtual Status UpdateTableFile(TableFileSchema &file_schema) = 0; + virtual Status + UpdateTableFile(TableFileSchema& file_schema) = 0; - virtual Status UpdateTableFiles(TableFilesSchema &files) = 0; + virtual Status + UpdateTableFiles(TableFilesSchema& files) = 0; - virtual Status FilesToSearch(const std::string &table_id, - const std::vector &ids, - const DatesT &partition, - DatePartionedTableFilesSchema &files) = 0; + virtual Status + FilesToSearch(const std::string& table_id, const std::vector& ids, const DatesT& partition, + DatePartionedTableFilesSchema& files) = 0; - virtual Status FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) = 0; + virtual Status + FilesToMerge(const std::string& table_id, DatePartionedTableFilesSchema& files) = 0; - virtual Status Size(uint64_t &result) = 0; + virtual Status + Size(uint64_t& result) = 0; - virtual Status Archive() = 0; + virtual Status + Archive() = 0; - virtual Status FilesToIndex(TableFilesSchema &) = 0; + virtual Status + FilesToIndex(TableFilesSchema&) = 0; - virtual Status FilesByType(const std::string &table_id, - const std::vector &file_types, - std::vector &file_ids) = 0; + virtual Status + FilesByType(const std::string& table_id, const std::vector& file_types, + std::vector& file_ids) = 0; - virtual Status DescribeTableIndex(const std::string &table_id, TableIndex &index) = 0; + virtual Status + DescribeTableIndex(const std::string& table_id, TableIndex& index) = 0; - virtual Status DropTableIndex(const std::string &table_id) = 0; + virtual Status + DropTableIndex(const std::string& table_id) = 0; - virtual Status CleanUp() = 0; + virtual Status + CleanUp() = 0; virtual Status CleanUpFilesWithTTL(uint16_t) = 0; - virtual Status DropAll() = 0; + virtual Status + DropAll() = 0; - virtual Status Count(const std::string &table_id, uint64_t &result) = 0; -}; // MetaData + virtual Status + Count(const std::string& table_id, uint64_t& result) = 0; +}; // MetaData using MetaPtr = std::shared_ptr; -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MetaConsts.h b/cpp/src/db/meta/MetaConsts.h index 6abae564e880f3076a372920f8d10db86b486b37..4e40ff77313507ec9866f10b0a2ac0f784fdb71e 100644 --- a/cpp/src/db/meta/MetaConsts.h +++ b/cpp/src/db/meta/MetaConsts.h @@ -17,16 +17,10 @@ #pragma once -namespace zilliz { namespace milvus { namespace engine { namespace meta { -const size_t K = 1024UL; -const size_t M = K * K; -const size_t G = K * M; -const size_t T = K * G; - const size_t S_PS = 1UL; const size_t MS_PS = 1000 * S_PS; const size_t US_PS = 1000 * MS_PS; @@ -38,7 +32,6 @@ const size_t H_SEC = 60 * M_SEC; const size_t D_SEC = 24 * H_SEC; const size_t W_SEC = 7 * D_SEC; -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MetaFactory.cpp b/cpp/src/db/meta/MetaFactory.cpp index dd0b2ae7eab7691ab1f1fd57e6baab3e056e1146..8031038e3701861b0f36bf52d6c0c8a89e911c73 100644 --- a/cpp/src/db/meta/MetaFactory.cpp +++ b/cpp/src/db/meta/MetaFactory.cpp @@ -16,33 +16,31 @@ // under the License. #include "db/meta/MetaFactory.h" -#include "SqliteMetaImpl.h" #include "MySQLMetaImpl.h" -#include "utils/Log.h" -#include "utils/Exception.h" +#include "SqliteMetaImpl.h" #include "db/Utils.h" +#include "utils/Exception.h" +#include "utils/Log.h" #include +#include #include -#include #include -#include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace engine { DBMetaOptions -MetaFactory::BuildOption(const std::string &path) { +MetaFactory::BuildOption(const std::string& path) { auto p = path; if (p == "") { srand(time(nullptr)); std::stringstream ss; - uint32_t rd = 0; - rand_r(&rd); - ss << "/tmp/" << rd; + uint32_t seed = 1; + ss << "/tmp/" << rand_r(&seed); p = ss.str(); } @@ -52,7 +50,7 @@ MetaFactory::BuildOption(const std::string &path) { } meta::MetaPtr -MetaFactory::Build(const DBMetaOptions &metaOptions, const int &mode) { +MetaFactory::Build(const DBMetaOptions& metaOptions, const int& mode) { std::string uri = metaOptions.backend_uri_; utils::MetaUriInfo uri_info; @@ -74,6 +72,5 @@ MetaFactory::Build(const DBMetaOptions &metaOptions, const int &mode) { } } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MetaFactory.h b/cpp/src/db/meta/MetaFactory.h index a13a09715ad40bbbf42752844493e6ce8b9230a0..f16584426ed42af0a9e2d4f67b8e26451ba6b37e 100644 --- a/cpp/src/db/meta/MetaFactory.h +++ b/cpp/src/db/meta/MetaFactory.h @@ -22,17 +22,17 @@ #include -namespace zilliz { namespace milvus { namespace engine { class MetaFactory { public: - static DBMetaOptions BuildOption(const std::string &path = ""); + static DBMetaOptions + BuildOption(const std::string& path = ""); - static meta::MetaPtr Build(const DBMetaOptions &metaOptions, const int &mode); + static meta::MetaPtr + Build(const DBMetaOptions& metaOptions, const int& mode); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MetaTypes.h b/cpp/src/db/meta/MetaTypes.h index 6fd649bc079e4e222a1797a07b6faf3790822454..c973f3fdeae40f18984d8a5ee937ee558771080d 100644 --- a/cpp/src/db/meta/MetaTypes.h +++ b/cpp/src/db/meta/MetaTypes.h @@ -17,22 +17,21 @@ #pragma once -#include "db/engine/ExecutionEngine.h" #include "db/Constants.h" +#include "db/engine/ExecutionEngine.h" -#include #include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace engine { namespace meta { -constexpr int32_t DEFAULT_ENGINE_TYPE = (int) EngineType::FAISS_IDMAP; +constexpr int32_t DEFAULT_ENGINE_TYPE = (int)EngineType::FAISS_IDMAP; constexpr int32_t DEFAULT_NLIST = 16384; -constexpr int32_t DEFAULT_METRIC_TYPE = (int) MetricType::L2; +constexpr int32_t DEFAULT_METRIC_TYPE = (int)MetricType::L2; constexpr int32_t DEFAULT_INDEX_FILE_SIZE = ONE_GB; constexpr int64_t FLAG_MASK_NO_USERID = 0x1; @@ -50,7 +49,7 @@ struct TableSchema { size_t id_ = 0; std::string table_id_; - int32_t state_ = (int) NORMAL; + int32_t state_ = (int)NORMAL; uint16_t dimension_ = 0; int64_t created_on_ = 0; int64_t flag_ = 0; @@ -58,7 +57,7 @@ struct TableSchema { int32_t engine_type_ = DEFAULT_ENGINE_TYPE; int32_t nlist_ = DEFAULT_NLIST; int32_t metric_type_ = DEFAULT_METRIC_TYPE; -}; // TableSchema +}; // TableSchema struct TableFileSchema { typedef enum { @@ -83,17 +82,16 @@ struct TableFileSchema { std::string location_; int64_t updated_time_ = 0; int64_t created_on_ = 0; - int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; //not persist to meta + int64_t index_file_size_ = DEFAULT_INDEX_FILE_SIZE; // not persist to meta int32_t engine_type_ = DEFAULT_ENGINE_TYPE; - int32_t nlist_ = DEFAULT_NLIST; //not persist to meta - int32_t metric_type_ = DEFAULT_METRIC_TYPE; //not persist to meta -}; // TableFileSchema + int32_t nlist_ = DEFAULT_NLIST; // not persist to meta + int32_t metric_type_ = DEFAULT_METRIC_TYPE; // not persist to meta +}; // TableFileSchema using TableFileSchemaPtr = std::shared_ptr; using TableFilesSchema = std::vector; using DatePartionedTableFilesSchema = std::map; -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MySQLConnectionPool.cpp b/cpp/src/db/meta/MySQLConnectionPool.cpp index 9d612d8b54a522c99585d39b4a55d4d8743a53c0..ef013dce95a3f17ac63a4adba5fe3b70a2666df6 100644 --- a/cpp/src/db/meta/MySQLConnectionPool.cpp +++ b/cpp/src/db/meta/MySQLConnectionPool.cpp @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. - #include "db/meta/MySQLConnectionPool.h" -namespace zilliz { namespace milvus { namespace engine { namespace meta { @@ -28,7 +26,7 @@ namespace meta { // already. Can't do this in create() because we're interested in // connections actually in use, not those created. Also note that // we keep our own count; ConnectionPool::size() isn't the same! -mysqlpp::Connection * +mysqlpp::Connection* MySQLConnectionPool::grab() { while (conns_in_use_ > max_pool_size_) { sleep(1); @@ -40,7 +38,7 @@ MySQLConnectionPool::grab() { // Other half of in-use conn count limit void -MySQLConnectionPool::release(const mysqlpp::Connection *pc) { +MySQLConnectionPool::release(const mysqlpp::Connection* pc) { mysqlpp::ConnectionPool::release(pc); if (conns_in_use_ <= 0) { ENGINE_LOG_WARNING << "MySQLConnetionPool::release: conns_in_use_ is less than zero. conns_in_use_ = " @@ -64,27 +62,25 @@ MySQLConnectionPool::getDB() { } // Superclass overrides -mysqlpp::Connection * +mysqlpp::Connection* MySQLConnectionPool::create() { try { // Create connection using the parameters we were passed upon // creation. - mysqlpp::Connection *conn = new mysqlpp::Connection(); + auto conn = new mysqlpp::Connection(); conn->set_option(new mysqlpp::ReconnectOption(true)); - conn->connect(db_.empty() ? 0 : db_.c_str(), - server_.empty() ? 0 : server_.c_str(), - user_.empty() ? 0 : user_.c_str(), - password_.empty() ? 0 : password_.c_str(), - port_); + conn->connect(db_.empty() ? 0 : db_.c_str(), server_.empty() ? 0 : server_.c_str(), + user_.empty() ? 0 : user_.c_str(), password_.empty() ? 0 : password_.c_str(), port_); return conn; - } catch (const mysqlpp::ConnectionFailed &er) { - ENGINE_LOG_ERROR << "Failed to connect to database server" << ": " << er.what(); + } catch (const mysqlpp::ConnectionFailed& er) { + ENGINE_LOG_ERROR << "Failed to connect to database server" + << ": " << er.what(); return nullptr; } } void -MySQLConnectionPool::destroy(mysqlpp::Connection *cp) { +MySQLConnectionPool::destroy(mysqlpp::Connection* cp) { // Our superclass can't know how we created the Connection, so // it delegates destruction to us, to be safe. delete cp; @@ -95,7 +91,6 @@ MySQLConnectionPool::max_idle_time() { return max_idle_time_; } -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MySQLConnectionPool.h b/cpp/src/db/meta/MySQLConnectionPool.h index 761f272bc1860d85b6f9799bc6af2efb2f018c6d..60e353c3c48e3c134301632470228dd53b2e483c 100644 --- a/cpp/src/db/meta/MySQLConnectionPool.h +++ b/cpp/src/db/meta/MySQLConnectionPool.h @@ -15,16 +15,14 @@ // specific language governing permissions and limitations // under the License. +#include -#include "mysql++/mysql++.h" - -#include #include #include +#include #include "utils/Log.h" -namespace zilliz { namespace milvus { namespace engine { namespace meta { @@ -32,20 +30,16 @@ namespace meta { class MySQLConnectionPool : public mysqlpp::ConnectionPool { public: // The object's only constructor - MySQLConnectionPool(std::string dbName, - std::string userName, - std::string passWord, - std::string serverIp, - int port = 0, - int maxPoolSize = 8) : - db_(dbName), - user_(userName), - password_(passWord), - server_(serverIp), - port_(port), - max_pool_size_(maxPoolSize) { + MySQLConnectionPool(std::string dbName, std::string userName, std::string passWord, std::string serverIp, + int port = 0, int maxPoolSize = 8) + : db_(dbName), + user_(userName), + password_(passWord), + server_(serverIp), + port_(port), + max_pool_size_(maxPoolSize) { conns_in_use_ = 0; - max_idle_time_ = 10; //10 seconds + max_idle_time_ = 10; // 10 seconds } // The destructor. We _must_ call ConnectionPool::clear() here, @@ -54,24 +48,30 @@ class MySQLConnectionPool : public mysqlpp::ConnectionPool { clear(); } - mysqlpp::Connection *grab() override; + mysqlpp::Connection* + grab() override; // Other half of in-use conn count limit - void release(const mysqlpp::Connection *pc) override; + void + release(const mysqlpp::Connection* pc) override; -// int getConnectionsInUse(); -// -// void set_max_idle_time(int max_idle); + // int getConnectionsInUse(); + // + // void set_max_idle_time(int max_idle); - std::string getDB(); + std::string + getDB(); protected: // Superclass overrides - mysqlpp::Connection *create() override; + mysqlpp::Connection* + create() override; - void destroy(mysqlpp::Connection *cp) override; + void + destroy(mysqlpp::Connection* cp) override; - unsigned int max_idle_time() override; + unsigned int + max_idle_time() override; private: // Number of connections currently in use @@ -86,7 +86,6 @@ class MySQLConnectionPool : public mysqlpp::ConnectionPool { unsigned int max_idle_time_; }; -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MySQLMetaImpl.cpp b/cpp/src/db/meta/MySQLMetaImpl.cpp index 490eda1916714a8568a04eb52e42a5e3e33d3264..f9f1569a65854c0199af68651e1e136ca606bc13 100644 --- a/cpp/src/db/meta/MySQLMetaImpl.cpp +++ b/cpp/src/db/meta/MySQLMetaImpl.cpp @@ -16,29 +16,28 @@ // under the License. #include "db/meta/MySQLMetaImpl.h" +#include "MetaConsts.h" #include "db/IDGenerator.h" #include "db/Utils.h" -#include "utils/Log.h" -#include "utils/Exception.h" -#include "MetaConsts.h" #include "metrics/Metrics.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include +#include #include -#include -#include +#include #include #include +#include +#include +#include #include +#include +#include #include -#include #include -#include -#include -#include -#include -#include -namespace zilliz { namespace milvus { namespace engine { namespace meta { @@ -46,40 +45,41 @@ namespace meta { namespace { Status -HandleException(const std::string &desc, const char *what = nullptr) { +HandleException(const std::string& desc, const char* what = nullptr) { if (what == nullptr) { ENGINE_LOG_ERROR << desc; return Status(DB_META_TRANSACTION_FAILED, desc); - } else { - std::string msg = desc + ":" + what; - ENGINE_LOG_ERROR << msg; - return Status(DB_META_TRANSACTION_FAILED, msg); } + + std::string msg = desc + ":" + what; + ENGINE_LOG_ERROR << msg; + return Status(DB_META_TRANSACTION_FAILED, msg); } class MetaField { public: - MetaField(const std::string &name, const std::string &type, const std::string &setting) - : name_(name), - type_(type), - setting_(setting) { + MetaField(const std::string& name, const std::string& type, const std::string& setting) + : name_(name), type_(type), setting_(setting) { } - std::string name() const { + std::string + name() const { return name_; } - std::string ToString() const { + std::string + ToString() const { return name_ + " " + type_ + " " + setting_; } // mysql field type has additional information. for instance, a filed type is defined as 'BIGINT' // we get the type from sql is 'bigint(20)', so we need to ignore the '(20)' - bool IsEqual(const MetaField &field) const { + bool + IsEqual(const MetaField& field) const { size_t name_len_min = field.name_.length() > name_.length() ? name_.length() : field.name_.length(); size_t type_len_min = field.type_.length() > type_.length() ? type_.length() : field.type_.length(); return strncasecmp(field.name_.c_str(), name_.c_str(), name_len_min) == 0 && - strncasecmp(field.type_.c_str(), type_.c_str(), type_len_min) == 0; + strncasecmp(field.type_.c_str(), type_.c_str(), type_len_min) == 0; } private: @@ -91,18 +91,18 @@ class MetaField { using MetaFields = std::vector; class MetaSchema { public: - MetaSchema(const std::string &name, const MetaFields &fields) - : name_(name), - fields_(fields) { + MetaSchema(const std::string& name, const MetaFields& fields) : name_(name), fields_(fields) { } - std::string name() const { + std::string + name() const { return name_; } - std::string ToString() const { + std::string + ToString() const { std::string result; - for (auto &field : fields_) { + for (auto& field : fields_) { if (!result.empty()) { result += ","; } @@ -111,12 +111,13 @@ class MetaSchema { return result; } - //if the outer fields contains all this MetaSchema fields, return true - //otherwise return false - bool IsEqual(const MetaFields &fields) const { + // if the outer fields contains all this MetaSchema fields, return true + // otherwise return false + bool + IsEqual(const MetaFields& fields) const { std::vector found_field; - for (const auto &this_field : fields_) { - for (const auto &outer_field : fields) { + for (const auto& this_field : fields_) { + for (const auto& outer_field : fields) { if (this_field.IsEqual(outer_field)) { found_field.push_back(this_field.name()); break; @@ -132,40 +133,38 @@ class MetaSchema { MetaFields fields_; }; -//Tables schema +// Tables schema static const MetaSchema TABLES_SCHEMA(META_TABLES, { - MetaField("id", "BIGINT", "PRIMARY KEY AUTO_INCREMENT"), - MetaField("table_id", "VARCHAR(255)", "UNIQUE NOT NULL"), - MetaField("state", "INT", "NOT NULL"), - MetaField("dimension", "SMALLINT", "NOT NULL"), - MetaField("created_on", "BIGINT", "NOT NULL"), - MetaField("flag", "BIGINT", "DEFAULT 0 NOT NULL"), - MetaField("index_file_size", "BIGINT", "DEFAULT 1024 NOT NULL"), - MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"), - MetaField("nlist", "INT", "DEFAULT 16384 NOT NULL"), - MetaField("metric_type", "INT", "DEFAULT 1 NOT NULL"), -}); - -//TableFiles schema + MetaField("id", "BIGINT", "PRIMARY KEY AUTO_INCREMENT"), + MetaField("table_id", "VARCHAR(255)", "UNIQUE NOT NULL"), + MetaField("state", "INT", "NOT NULL"), + MetaField("dimension", "SMALLINT", "NOT NULL"), + MetaField("created_on", "BIGINT", "NOT NULL"), + MetaField("flag", "BIGINT", "DEFAULT 0 NOT NULL"), + MetaField("index_file_size", "BIGINT", "DEFAULT 1024 NOT NULL"), + MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"), + MetaField("nlist", "INT", "DEFAULT 16384 NOT NULL"), + MetaField("metric_type", "INT", "DEFAULT 1 NOT NULL"), + }); + +// TableFiles schema static const MetaSchema TABLEFILES_SCHEMA(META_TABLEFILES, { - MetaField("id", "BIGINT", "PRIMARY KEY AUTO_INCREMENT"), - MetaField("table_id", "VARCHAR(255)", "NOT NULL"), - MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"), - MetaField("file_id", "VARCHAR(255)", "NOT NULL"), - MetaField("file_type", "INT", "DEFAULT 0 NOT NULL"), - MetaField("file_size", "BIGINT", "DEFAULT 0 NOT NULL"), - MetaField("row_count", "BIGINT", "DEFAULT 0 NOT NULL"), - MetaField("updated_time", "BIGINT", "NOT NULL"), - MetaField("created_on", "BIGINT", "NOT NULL"), - MetaField("date", "INT", "DEFAULT -1 NOT NULL"), -}); - -} // namespace + MetaField("id", "BIGINT", "PRIMARY KEY AUTO_INCREMENT"), + MetaField("table_id", "VARCHAR(255)", "NOT NULL"), + MetaField("engine_type", "INT", "DEFAULT 1 NOT NULL"), + MetaField("file_id", "VARCHAR(255)", "NOT NULL"), + MetaField("file_type", "INT", "DEFAULT 0 NOT NULL"), + MetaField("file_size", "BIGINT", "DEFAULT 0 NOT NULL"), + MetaField("row_count", "BIGINT", "DEFAULT 0 NOT NULL"), + MetaField("updated_time", "BIGINT", "NOT NULL"), + MetaField("created_on", "BIGINT", "NOT NULL"), + MetaField("date", "INT", "DEFAULT -1 NOT NULL"), + }); + +} // namespace //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -MySQLMetaImpl::MySQLMetaImpl(const DBMetaOptions &options, const int &mode) - : options_(options), - mode_(mode) { +MySQLMetaImpl::MySQLMetaImpl(const DBMetaOptions& options, const int& mode) : options_(options), mode_(mode) { Initialize(); } @@ -173,7 +172,7 @@ MySQLMetaImpl::~MySQLMetaImpl() { } Status -MySQLMetaImpl::NextTableId(std::string &table_id) { +MySQLMetaImpl::NextTableId(std::string& table_id) { std::stringstream ss; SimpleIDGenerator g; ss << g.GetNextIDNumber(); @@ -182,7 +181,7 @@ MySQLMetaImpl::NextTableId(std::string &table_id) { } Status -MySQLMetaImpl::NextFileId(std::string &file_id) { +MySQLMetaImpl::NextFileId(std::string& file_id) { std::stringstream ss; SimpleIDGenerator g; ss << g.GetNextIDNumber(); @@ -201,7 +200,7 @@ MySQLMetaImpl::ValidateMetaSchema() { return; } - auto validate_func = [&](const MetaSchema &schema) { + auto validate_func = [&](const MetaSchema& schema) { mysqlpp::Query query_statement = connectionPtr->query(); query_statement << "DESC " << schema.name() << ";"; @@ -210,14 +209,14 @@ MySQLMetaImpl::ValidateMetaSchema() { try { mysqlpp::StoreQueryResult res = query_statement.store(); for (size_t i = 0; i < res.num_rows(); i++) { - const mysqlpp::Row &row = res[i]; + const mysqlpp::Row& row = res[i]; std::string name, type; row["Field"].to_string(name); row["Type"].to_string(type); exist_fields.push_back(MetaField(name, type, "")); } - } catch (std::exception &e) { + } catch (std::exception& e) { ENGINE_LOG_DEBUG << "Meta table '" << schema.name() << "' not exist and will be created"; } @@ -228,12 +227,12 @@ MySQLMetaImpl::ValidateMetaSchema() { return schema.IsEqual(exist_fields); }; - //verify Tables + // verify Tables if (!validate_func(TABLES_SCHEMA)) { throw Exception(DB_INCOMPATIB_META, "Meta Tables schema is created by Milvus old version"); } - //verufy TableFiles + // verufy TableFiles if (!validate_func(TABLEFILES_SCHEMA)) { throw Exception(DB_INCOMPATIB_META, "Meta TableFiles schema is created by Milvus old version"); } @@ -241,7 +240,7 @@ MySQLMetaImpl::ValidateMetaSchema() { Status MySQLMetaImpl::Initialize() { - //step 1: create db root path + // step 1: create db root path if (!boost::filesystem::is_directory(options_.path_)) { auto ret = boost::filesystem::create_directory(options_.path_); if (!ret) { @@ -253,7 +252,7 @@ MySQLMetaImpl::Initialize() { std::string uri = options_.backend_uri_; - //step 2: parse and check meta uri + // step 2: parse and check meta uri utils::MetaUriInfo uri_info; auto status = utils::ParseMetaUri(uri, uri_info); if (!status.ok()) { @@ -268,7 +267,7 @@ MySQLMetaImpl::Initialize() { throw Exception(DB_INVALID_META_URI, msg); } - //step 3: connect mysql + // step 3: connect mysql int thread_hint = std::thread::hardware_concurrency(); int max_pool_size = (thread_hint == 0) ? 8 : thread_hint; unsigned int port = 0; @@ -276,15 +275,14 @@ MySQLMetaImpl::Initialize() { port = std::stoi(uri_info.port_); } - mysql_connection_pool_ = - std::make_shared(uri_info.db_name_, uri_info.username_, - uri_info.password_, uri_info.host_, port, max_pool_size); + mysql_connection_pool_ = std::make_shared( + uri_info.db_name_, uri_info.username_, uri_info.password_, uri_info.host_, port, max_pool_size); ENGINE_LOG_DEBUG << "MySQL connection pool: maximum pool size = " << std::to_string(max_pool_size); - //step 4: validate to avoid open old version schema + // step 4: validate to avoid open old version schema ValidateMetaSchema(); - //step 5: create meta tables + // step 5: create meta tables try { if (mode_ != DBOptions::MODE::CLUSTER_READONLY) { CleanUp(); @@ -303,8 +301,8 @@ MySQLMetaImpl::Initialize() { } mysqlpp::Query InitializeQuery = connectionPtr->query(); - InitializeQuery << "CREATE TABLE IF NOT EXISTS " << - TABLES_SCHEMA.name() << " (" << TABLES_SCHEMA.ToString() + ");"; + InitializeQuery << "CREATE TABLE IF NOT EXISTS " << TABLES_SCHEMA.name() << " (" + << TABLES_SCHEMA.ToString() + ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::Initialize: " << InitializeQuery.str(); @@ -312,26 +310,25 @@ MySQLMetaImpl::Initialize() { return HandleException("Initialization Error", InitializeQuery.error()); } - InitializeQuery << "CREATE TABLE IF NOT EXISTS " << - TABLEFILES_SCHEMA.name() << " (" << TABLEFILES_SCHEMA.ToString() + ");"; + InitializeQuery << "CREATE TABLE IF NOT EXISTS " << TABLEFILES_SCHEMA.name() << " (" + << TABLEFILES_SCHEMA.ToString() + ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::Initialize: " << InitializeQuery.str(); if (!InitializeQuery.exec()) { return HandleException("Initialization Error", InitializeQuery.error()); } - } //Scoped Connection - } catch (std::exception &e) { + } // Scoped Connection + } catch (std::exception& e) { return HandleException("GENERAL ERROR DURING INITIALIZATION", e.what()); } return Status::OK(); } -// PXU TODO: Temp solution. Will fix later +// TODO(myh): Delete single vecotor by id Status -MySQLMetaImpl::DropPartitionsByDates(const std::string &table_id, - const DatesT &dates) { +MySQLMetaImpl::DropPartitionsByDates(const std::string& table_id, const DatesT& dates) { if (dates.empty()) { return Status::OK(); } @@ -345,11 +342,11 @@ MySQLMetaImpl::DropPartitionsByDates(const std::string &table_id, try { std::stringstream dateListSS; - for (auto &date : dates) { + for (auto& date : dates) { dateListSS << std::to_string(date) << ", "; } std::string dateListStr = dateListSS.str(); - dateListStr = dateListStr.substr(0, dateListStr.size() - 2); //remove the last ", " + dateListStr = dateListStr.substr(0, dateListStr.size() - 2); // remove the last ", " { mysqlpp::ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab_); @@ -360,12 +357,11 @@ MySQLMetaImpl::DropPartitionsByDates(const std::string &table_id, mysqlpp::Query dropPartitionsByDatesQuery = connectionPtr->query(); - dropPartitionsByDatesQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << "," << - "updated_time = " << utils::GetMicroSecTimeStamp() << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "date in (" << dateListStr << ");"; + dropPartitionsByDatesQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << "," + << "updated_time = " << utils::GetMicroSecTimeStamp() << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "date in (" << dateListStr << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropPartitionsByDates: " << dropPartitionsByDatesQuery.str(); @@ -373,17 +369,17 @@ MySQLMetaImpl::DropPartitionsByDates(const std::string &table_id, return HandleException("QUERY ERROR WHEN DROPPING PARTITIONS BY DATES", dropPartitionsByDatesQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully drop partitions, table id = " << table_schema.table_id_; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DROPPING PARTITIONS BY DATES", e.what()); } return Status::OK(); } Status -MySQLMetaImpl::CreateTable(TableSchema &table_schema) { +MySQLMetaImpl::CreateTable(TableSchema& table_schema) { try { server::MetricCollector metric; { @@ -398,9 +394,8 @@ MySQLMetaImpl::CreateTable(TableSchema &table_schema) { if (table_schema.table_id_.empty()) { NextTableId(table_schema.table_id_); } else { - createTableQuery << "SELECT state FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << table_schema.table_id_ << ";"; + createTableQuery << "SELECT state FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << table_schema.table_id_ << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str(); @@ -419,7 +414,7 @@ MySQLMetaImpl::CreateTable(TableSchema &table_schema) { table_schema.id_ = -1; table_schema.created_on_ = utils::GetMicroSecTimeStamp(); - std::string id = "NULL"; //auto-increment + std::string id = "NULL"; // auto-increment std::string table_id = table_schema.table_id_; std::string state = std::to_string(table_schema.state_); std::string dimension = std::to_string(table_schema.dimension_); @@ -430,35 +425,32 @@ MySQLMetaImpl::CreateTable(TableSchema &table_schema) { std::string nlist = std::to_string(table_schema.nlist_); std::string metric_type = std::to_string(table_schema.metric_type_); - createTableQuery << "INSERT INTO " << - META_TABLES << " " << - "VALUES(" << id << ", " << mysqlpp::quote << table_id << ", " << - state << ", " << dimension << ", " << created_on << ", " << - flag << ", " << index_file_size << ", " << engine_type << ", " << - nlist << ", " << metric_type << ");"; + createTableQuery << "INSERT INTO " << META_TABLES << " " + << "VALUES(" << id << ", " << mysqlpp::quote << table_id << ", " << state << ", " + << dimension << ", " << created_on << ", " << flag << ", " << index_file_size << ", " + << engine_type << ", " << nlist << ", " << metric_type << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str(); if (mysqlpp::SimpleResult res = createTableQuery.execute()) { - table_schema.id_ = res.insert_id(); //Might need to use SELECT LAST_INSERT_ID()? + table_schema.id_ = res.insert_id(); // Might need to use SELECT LAST_INSERT_ID()? - //Consume all results to avoid "Commands out of sync" error + // Consume all results to avoid "Commands out of sync" error } else { return HandleException("Add Table Error", createTableQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully create table: " << table_schema.table_id_; return utils::CreateTablePath(options_, table_schema.table_id_); - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CREATING TABLE", e.what()); } } Status -MySQLMetaImpl::FilesByType(const std::string &table_id, - const std::vector &file_types, - std::vector &file_ids) { +MySQLMetaImpl::FilesByType(const std::string& table_id, const std::vector& file_types, + std::vector& file_ids) { if (file_types.empty()) { return Status(DB_ERROR, "file types array is empty"); } @@ -483,42 +475,49 @@ MySQLMetaImpl::FilesByType(const std::string &table_id, } mysqlpp::Query hasNonIndexFilesQuery = connectionPtr->query(); - //since table_id is a unique column we just need to check whether it exists or not - hasNonIndexFilesQuery << "SELECT file_id, file_type FROM " << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type in (" << types << ");"; + // since table_id is a unique column we just need to check whether it exists or not + hasNonIndexFilesQuery << "SELECT file_id, file_type FROM " << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type in (" << types << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::FilesByType: " << hasNonIndexFilesQuery.str(); res = hasNonIndexFilesQuery.store(); - } //Scoped Connection + } // Scoped Connection if (res.num_rows() > 0) { int raw_count = 0, new_count = 0, new_merge_count = 0, new_index_count = 0; int to_index_count = 0, index_count = 0, backup_count = 0; - for (auto &resRow : res) { + for (auto& resRow : res) { std::string file_id; resRow["file_id"].to_string(file_id); file_ids.push_back(file_id); int32_t file_type = resRow["file_type"]; switch (file_type) { - case (int) TableFileSchema::RAW:raw_count++; + case (int)TableFileSchema::RAW: + raw_count++; + break; + case (int)TableFileSchema::NEW: + new_count++; break; - case (int) TableFileSchema::NEW:new_count++; + case (int)TableFileSchema::NEW_MERGE: + new_merge_count++; break; - case (int) TableFileSchema::NEW_MERGE:new_merge_count++; + case (int)TableFileSchema::NEW_INDEX: + new_index_count++; break; - case (int) TableFileSchema::NEW_INDEX:new_index_count++; + case (int)TableFileSchema::TO_INDEX: + to_index_count++; break; - case (int) TableFileSchema::TO_INDEX:to_index_count++; + case (int)TableFileSchema::INDEX: + index_count++; break; - case (int) TableFileSchema::INDEX:index_count++; + case (int)TableFileSchema::BACKUP: + backup_count++; break; - case (int) TableFileSchema::BACKUP:backup_count++; + default: break; - default:break; } } @@ -527,7 +526,7 @@ MySQLMetaImpl::FilesByType(const std::string &table_id, << " new_index files:" << new_index_count << " to_index files:" << to_index_count << " index files:" << index_count << " backup files:" << backup_count; } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN GET FILE BY TYPE", e.what()); } @@ -535,7 +534,7 @@ MySQLMetaImpl::FilesByType(const std::string &table_id, } Status -MySQLMetaImpl::UpdateTableIndex(const std::string &table_id, const TableIndex &index) { +MySQLMetaImpl::UpdateTableIndex(const std::string& table_id, const TableIndex& index) { try { server::MetricCollector metric; @@ -547,33 +546,31 @@ MySQLMetaImpl::UpdateTableIndex(const std::string &table_id, const TableIndex &i } mysqlpp::Query updateTableIndexParamQuery = connectionPtr->query(); - updateTableIndexParamQuery << "SELECT id, state, dimension, created_on FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; + updateTableIndexParamQuery << "SELECT id, state, dimension, created_on FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableIndex: " << updateTableIndexParamQuery.str(); mysqlpp::StoreQueryResult res = updateTableIndexParamQuery.store(); if (res.num_rows() == 1) { - const mysqlpp::Row &resRow = res[0]; + const mysqlpp::Row& resRow = res[0]; size_t id = resRow["id"]; int32_t state = resRow["state"]; uint16_t dimension = resRow["dimension"]; int64_t created_on = resRow["created_on"]; - updateTableIndexParamQuery << "UPDATE " << - META_TABLES << " " << - "SET id = " << id << ", " << - "state = " << state << ", " << - "dimension = " << dimension << ", " << - "created_on = " << created_on << ", " << - "engine_type = " << index.engine_type_ << ", " << - "nlist = " << index.nlist_ << ", " << - "metric_type = " << index.metric_type_ << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << ";"; + updateTableIndexParamQuery << "UPDATE " << META_TABLES << " " + << "SET id = " << id << ", " + << "state = " << state << ", " + << "dimension = " << dimension << ", " + << "created_on = " << created_on << ", " + << "engine_type = " << index.engine_type_ << ", " + << "nlist = " << index.nlist_ << ", " + << "metric_type = " << index.metric_type_ << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableIndex: " << updateTableIndexParamQuery.str(); @@ -584,10 +581,10 @@ MySQLMetaImpl::UpdateTableIndex(const std::string &table_id, const TableIndex &i } else { return Status(DB_NOT_FOUND, "Table " + table_id + " not found"); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully update table index, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE INDEX PARAM", e.what()); } @@ -595,7 +592,7 @@ MySQLMetaImpl::UpdateTableIndex(const std::string &table_id, const TableIndex &i } Status -MySQLMetaImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) { +MySQLMetaImpl::UpdateTableFlag(const std::string& table_id, int64_t flag) { try { server::MetricCollector metric; @@ -607,20 +604,19 @@ MySQLMetaImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) { } mysqlpp::Query updateTableFlagQuery = connectionPtr->query(); - updateTableFlagQuery << "UPDATE " << - META_TABLES << " " << - "SET flag = " << flag << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << ";"; + updateTableFlagQuery << "UPDATE " << META_TABLES << " " + << "SET flag = " << flag << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFlag: " << updateTableFlagQuery.str(); if (!updateTableFlagQuery.exec()) { return HandleException("QUERY ERROR WHEN UPDATING TABLE FLAG", updateTableFlagQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully update table flag, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE FLAG", e.what()); } @@ -628,7 +624,7 @@ MySQLMetaImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) { } Status -MySQLMetaImpl::DescribeTableIndex(const std::string &table_id, TableIndex &index) { +MySQLMetaImpl::DescribeTableIndex(const std::string& table_id, TableIndex& index) { try { server::MetricCollector metric; @@ -640,17 +636,17 @@ MySQLMetaImpl::DescribeTableIndex(const std::string &table_id, TableIndex &index } mysqlpp::Query describeTableIndexQuery = connectionPtr->query(); - describeTableIndexQuery << "SELECT engine_type, nlist, index_file_size, metric_type FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; + describeTableIndexQuery << "SELECT engine_type, nlist, index_file_size, metric_type FROM " << META_TABLES + << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DescribeTableIndex: " << describeTableIndexQuery.str(); mysqlpp::StoreQueryResult res = describeTableIndexQuery.store(); if (res.num_rows() == 1) { - const mysqlpp::Row &resRow = res[0]; + const mysqlpp::Row& resRow = res[0]; index.engine_type_ = resRow["engine_type"]; index.nlist_ = resRow["nlist"]; @@ -658,8 +654,8 @@ MySQLMetaImpl::DescribeTableIndex(const std::string &table_id, TableIndex &index } else { return Status(DB_NOT_FOUND, "Table " + table_id + " not found"); } - } //Scoped Connection - } catch (std::exception &e) { + } // Scoped Connection + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE FLAG", e.what()); } @@ -667,7 +663,7 @@ MySQLMetaImpl::DescribeTableIndex(const std::string &table_id, TableIndex &index } Status -MySQLMetaImpl::DropTableIndex(const std::string &table_id) { +MySQLMetaImpl::DropTableIndex(const std::string& table_id) { try { server::MetricCollector metric; @@ -680,13 +676,12 @@ MySQLMetaImpl::DropTableIndex(const std::string &table_id) { mysqlpp::Query dropTableIndexQuery = connectionPtr->query(); - //soft delete index files - dropTableIndexQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << "," << - "updated_time = " << utils::GetMicroSecTimeStamp() << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type = " << std::to_string(TableFileSchema::INDEX) << ";"; + // soft delete index files + dropTableIndexQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << "," + << "updated_time = " << utils::GetMicroSecTimeStamp() << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type = " << std::to_string(TableFileSchema::INDEX) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str(); @@ -694,13 +689,12 @@ MySQLMetaImpl::DropTableIndex(const std::string &table_id) { return HandleException("QUERY ERROR WHEN DROPPING TABLE INDEX", dropTableIndexQuery.error()); } - //set all backup file to raw - dropTableIndexQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::RAW) << "," << - "updated_time = " << utils::GetMicroSecTimeStamp() << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type = " << std::to_string(TableFileSchema::BACKUP) << ";"; + // set all backup file to raw + dropTableIndexQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::RAW) << "," + << "updated_time = " << utils::GetMicroSecTimeStamp() << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type = " << std::to_string(TableFileSchema::BACKUP) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str(); @@ -708,23 +702,22 @@ MySQLMetaImpl::DropTableIndex(const std::string &table_id) { return HandleException("QUERY ERROR WHEN DROPPING TABLE INDEX", dropTableIndexQuery.error()); } - //set table index type to raw - dropTableIndexQuery << "UPDATE " << - META_TABLES << " " << - "SET engine_type = " << std::to_string(DEFAULT_ENGINE_TYPE) << "," << - "nlist = " << std::to_string(DEFAULT_NLIST) << ", " << - "metric_type = " << std::to_string(DEFAULT_METRIC_TYPE) << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << ";"; + // set table index type to raw + dropTableIndexQuery << "UPDATE " << META_TABLES << " " + << "SET engine_type = " << std::to_string(DEFAULT_ENGINE_TYPE) << "," + << "nlist = " << std::to_string(DEFAULT_NLIST) << ", " + << "metric_type = " << std::to_string(DEFAULT_METRIC_TYPE) << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DropTableIndex: " << dropTableIndexQuery.str(); if (!dropTableIndexQuery.exec()) { return HandleException("QUERY ERROR WHEN DROPPING TABLE INDEX", dropTableIndexQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully drop table index, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DROPPING TABLE INDEX", e.what()); } @@ -732,7 +725,7 @@ MySQLMetaImpl::DropTableIndex(const std::string &table_id) { } Status -MySQLMetaImpl::DeleteTable(const std::string &table_id) { +MySQLMetaImpl::DeleteTable(const std::string& table_id) { try { server::MetricCollector metric; { @@ -742,27 +735,26 @@ MySQLMetaImpl::DeleteTable(const std::string &table_id) { return Status(DB_ERROR, "Failed to connect to database server"); } - //soft delete table + // soft delete table mysqlpp::Query deleteTableQuery = connectionPtr->query(); -// - deleteTableQuery << "UPDATE " << - META_TABLES << " " << - "SET state = " << std::to_string(TableSchema::TO_DELETE) << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << ";"; + // + deleteTableQuery << "UPDATE " << META_TABLES << " " + << "SET state = " << std::to_string(TableSchema::TO_DELETE) << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DeleteTable: " << deleteTableQuery.str(); if (!deleteTableQuery.exec()) { return HandleException("QUERY ERROR WHEN DELETING TABLE", deleteTableQuery.error()); } - } //Scoped Connection + } // Scoped Connection if (mode_ == DBOptions::MODE::CLUSTER_WRITABLE) { DeleteTableFiles(table_id); } ENGINE_LOG_DEBUG << "Successfully delete table, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DELETING TABLE", e.what()); } @@ -770,7 +762,7 @@ MySQLMetaImpl::DeleteTable(const std::string &table_id) { } Status -MySQLMetaImpl::DeleteTableFiles(const std::string &table_id) { +MySQLMetaImpl::DeleteTableFiles(const std::string& table_id) { try { server::MetricCollector metric; { @@ -780,25 +772,24 @@ MySQLMetaImpl::DeleteTableFiles(const std::string &table_id) { return Status(DB_ERROR, "Failed to connect to database server"); } - //soft delete table files + // soft delete table files mysqlpp::Query deleteTableFilesQuery = connectionPtr->query(); // - deleteTableFilesQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << ", " << - "updated_time = " << std::to_string(utils::GetMicroSecTimeStamp()) << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; + deleteTableFilesQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << ", " + << "updated_time = " << std::to_string(utils::GetMicroSecTimeStamp()) << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DeleteTableFiles: " << deleteTableFilesQuery.str(); if (!deleteTableFilesQuery.exec()) { return HandleException("QUERY ERROR WHEN DELETING TABLE FILES", deleteTableFilesQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully delete table files, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DELETING TABLE FILES", e.what()); } @@ -806,7 +797,7 @@ MySQLMetaImpl::DeleteTableFiles(const std::string &table_id) { } Status -MySQLMetaImpl::DescribeTable(TableSchema &table_schema) { +MySQLMetaImpl::DescribeTable(TableSchema& table_schema) { try { server::MetricCollector metric; mysqlpp::StoreQueryResult res; @@ -820,19 +811,19 @@ MySQLMetaImpl::DescribeTable(TableSchema &table_schema) { mysqlpp::Query describeTableQuery = connectionPtr->query(); describeTableQuery << "SELECT id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type " - << " FROM " << META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << table_schema.table_id_ << " " << - "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; + << " FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << table_schema.table_id_ << " " + << "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DescribeTable: " << describeTableQuery.str(); res = describeTableQuery.store(); - } //Scoped Connection + } // Scoped Connection if (res.num_rows() == 1) { - const mysqlpp::Row &resRow = res[0]; + const mysqlpp::Row& resRow = res[0]; - table_schema.id_ = resRow["id"]; //implicit conversion + table_schema.id_ = resRow["id"]; // implicit conversion table_schema.state_ = resRow["state"]; @@ -852,7 +843,7 @@ MySQLMetaImpl::DescribeTable(TableSchema &table_schema) { } else { return Status(DB_NOT_FOUND, "Table " + table_schema.table_id_ + " not found"); } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DESCRIBING TABLE", e.what()); } @@ -860,7 +851,7 @@ MySQLMetaImpl::DescribeTable(TableSchema &table_schema) { } Status -MySQLMetaImpl::HasTable(const std::string &table_id, bool &has_or_not) { +MySQLMetaImpl::HasTable(const std::string& table_id, bool& has_or_not) { try { server::MetricCollector metric; mysqlpp::StoreQueryResult res; @@ -872,22 +863,22 @@ MySQLMetaImpl::HasTable(const std::string &table_id, bool &has_or_not) { } mysqlpp::Query hasTableQuery = connectionPtr->query(); - //since table_id is a unique column we just need to check whether it exists or not - hasTableQuery << "SELECT EXISTS " << - "(SELECT 1 FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " " << - "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ") " << - "AS " << mysqlpp::quote << "check" << ";"; + // since table_id is a unique column we just need to check whether it exists or not + hasTableQuery << "SELECT EXISTS " + << "(SELECT 1 FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " " + << "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ") " + << "AS " << mysqlpp::quote << "check" + << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::HasTable: " << hasTableQuery.str(); res = hasTableQuery.store(); - } //Scoped Connection + } // Scoped Connection int check = res[0]["check"]; has_or_not = (check == 1); - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CHECKING IF TABLE EXISTS", e.what()); } @@ -895,7 +886,7 @@ MySQLMetaImpl::HasTable(const std::string &table_id, bool &has_or_not) { } Status -MySQLMetaImpl::AllTables(std::vector &table_schema_array) { +MySQLMetaImpl::AllTables(std::vector& table_schema_array) { try { server::MetricCollector metric; mysqlpp::StoreQueryResult res; @@ -908,19 +899,18 @@ MySQLMetaImpl::AllTables(std::vector &table_schema_array) { mysqlpp::Query allTablesQuery = connectionPtr->query(); allTablesQuery << "SELECT id, table_id, dimension, engine_type, nlist, index_file_size, metric_type FROM " - << - META_TABLES << " " << - "WHERE state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; + << META_TABLES << " " + << "WHERE state <> " << std::to_string(TableSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::AllTables: " << allTablesQuery.str(); res = allTablesQuery.store(); - } //Scoped Connection + } // Scoped Connection - for (auto &resRow : res) { + for (auto& resRow : res) { TableSchema table_schema; - table_schema.id_ = resRow["id"]; //implicit conversion + table_schema.id_ = resRow["id"]; // implicit conversion std::string table_id; resRow["table_id"].to_string(table_id); @@ -938,7 +928,7 @@ MySQLMetaImpl::AllTables(std::vector &table_schema_array) { table_schema_array.emplace_back(table_schema); } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DESCRIBING ALL TABLES", e.what()); } @@ -946,7 +936,7 @@ MySQLMetaImpl::AllTables(std::vector &table_schema_array) { } Status -MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { +MySQLMetaImpl::CreateTableFile(TableFileSchema& file_schema) { if (file_schema.date_ == EmptyDate) { file_schema.date_ = utils::GetDate(); } @@ -971,7 +961,7 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { file_schema.nlist_ = table_schema.nlist_; file_schema.metric_type_ = table_schema.metric_type_; - std::string id = "NULL"; //auto-increment + std::string id = "NULL"; // auto-increment std::string table_id = file_schema.table_id_; std::string engine_type = std::to_string(file_schema.engine_type_); std::string file_id = file_schema.file_id_; @@ -991,33 +981,31 @@ MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { mysqlpp::Query createTableFileQuery = connectionPtr->query(); - createTableFileQuery << "INSERT INTO " << - META_TABLEFILES << " " << - "VALUES(" << id << ", " << mysqlpp::quote << table_id << - ", " << engine_type << ", " << - mysqlpp::quote << file_id << ", " << file_type << ", " << file_size << ", " << - row_count << ", " << updated_time << ", " << created_on << ", " << date << ");"; + createTableFileQuery << "INSERT INTO " << META_TABLEFILES << " " + << "VALUES(" << id << ", " << mysqlpp::quote << table_id << ", " << engine_type << ", " + << mysqlpp::quote << file_id << ", " << file_type << ", " << file_size << ", " + << row_count << ", " << updated_time << ", " << created_on << ", " << date << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTableFile: " << createTableFileQuery.str(); if (mysqlpp::SimpleResult res = createTableFileQuery.execute()) { - file_schema.id_ = res.insert_id(); //Might need to use SELECT LAST_INSERT_ID()? + file_schema.id_ = res.insert_id(); // Might need to use SELECT LAST_INSERT_ID()? - //Consume all results to avoid "Commands out of sync" error + // Consume all results to avoid "Commands out of sync" error } else { return HandleException("QUERY ERROR WHEN CREATING TABLE FILE", createTableFileQuery.error()); } - } // Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Successfully create table file, file id = " << file_schema.file_id_; return utils::CreateTableFilePath(options_, file_schema); - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CREATING TABLE FILE", e.what()); } } Status -MySQLMetaImpl::FilesToIndex(TableFilesSchema &files) { +MySQLMetaImpl::FilesToIndex(TableFilesSchema& files) { files.clear(); try { @@ -1033,20 +1021,19 @@ MySQLMetaImpl::FilesToIndex(TableFilesSchema &files) { mysqlpp::Query filesToIndexQuery = connectionPtr->query(); filesToIndexQuery << "SELECT id, table_id, engine_type, file_id, file_type, file_size, row_count, date, created_on FROM " - << - META_TABLEFILES << " " << - "WHERE file_type = " << std::to_string(TableFileSchema::TO_INDEX) << ";"; + << META_TABLEFILES << " " + << "WHERE file_type = " << std::to_string(TableFileSchema::TO_INDEX) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::FilesToIndex: " << filesToIndexQuery.str(); res = filesToIndexQuery.store(); - } //Scoped Connection + } // Scoped Connection Status ret; std::map groups; TableFileSchema table_file; - for (auto &resRow : res) { - table_file.id_ = resRow["id"]; //implicit conversion + for (auto& resRow : res) { + table_file.id_ = resRow["id"]; // implicit conversion std::string table_id; resRow["table_id"].to_string(table_id); @@ -1095,16 +1082,14 @@ MySQLMetaImpl::FilesToIndex(TableFilesSchema &files) { ENGINE_LOG_DEBUG << "Collect " << res.size() << " to-index files"; } return ret; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN FINDING TABLE FILES TO INDEX", e.what()); } } Status -MySQLMetaImpl::FilesToSearch(const std::string &table_id, - const std::vector &ids, - const DatesT &partition, - DatePartionedTableFilesSchema &files) { +MySQLMetaImpl::FilesToSearch(const std::string& table_id, const std::vector& ids, const DatesT& partition, + DatePartionedTableFilesSchema& files) { files.clear(); try { @@ -1119,41 +1104,43 @@ MySQLMetaImpl::FilesToSearch(const std::string &table_id, mysqlpp::Query filesToSearchQuery = connectionPtr->query(); filesToSearchQuery - << "SELECT id, table_id, engine_type, file_id, file_type, file_size, row_count, date FROM " << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id; + << "SELECT id, table_id, engine_type, file_id, file_type, file_size, row_count, date FROM " + << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id; if (!partition.empty()) { std::stringstream partitionListSS; - for (auto &date : partition) { + for (auto& date : partition) { partitionListSS << std::to_string(date) << ", "; } std::string partitionListStr = partitionListSS.str(); - partitionListStr = partitionListStr.substr(0, partitionListStr.size() - 2); //remove the last ", " - filesToSearchQuery << " AND " << "date IN (" << partitionListStr << ")"; + partitionListStr = partitionListStr.substr(0, partitionListStr.size() - 2); // remove the last ", " + filesToSearchQuery << " AND " + << "date IN (" << partitionListStr << ")"; } if (!ids.empty()) { std::stringstream idSS; - for (auto &id : ids) { + for (auto& id : ids) { idSS << "id = " << std::to_string(id) << " OR "; } std::string idStr = idSS.str(); - idStr = idStr.substr(0, idStr.size() - 4); //remove the last " OR " + idStr = idStr.substr(0, idStr.size() - 4); // remove the last " OR " - filesToSearchQuery << " AND " << "(" << idStr << ")"; + filesToSearchQuery << " AND " + << "(" << idStr << ")"; } // End - filesToSearchQuery << " AND " << - "(file_type = " << std::to_string(TableFileSchema::RAW) << " OR " << - "file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " OR " << - "file_type = " << std::to_string(TableFileSchema::INDEX) << ");"; + filesToSearchQuery << " AND " + << "(file_type = " << std::to_string(TableFileSchema::RAW) << " OR " + << "file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " OR " + << "file_type = " << std::to_string(TableFileSchema::INDEX) << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::FilesToSearch: " << filesToSearchQuery.str(); res = filesToSearchQuery.store(); - } //Scoped Connection + } // Scoped Connection TableSchema table_schema; table_schema.table_id_ = table_id; @@ -1164,8 +1151,8 @@ MySQLMetaImpl::FilesToSearch(const std::string &table_id, Status ret; TableFileSchema table_file; - for (auto &resRow : res) { - table_file.id_ = resRow["id"]; //implicit conversion + for (auto& resRow : res) { + table_file.id_ = resRow["id"]; // implicit conversion std::string table_id_str; resRow["table_id"].to_string(table_id_str); @@ -1210,20 +1197,19 @@ MySQLMetaImpl::FilesToSearch(const std::string &table_id, ENGINE_LOG_DEBUG << "Collect " << res.size() << " to-search files"; } return ret; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN FINDING TABLE FILES TO SEARCH", e.what()); } } Status -MySQLMetaImpl::FilesToMerge(const std::string &table_id, - DatePartionedTableFilesSchema &files) { +MySQLMetaImpl::FilesToMerge(const std::string& table_id, DatePartionedTableFilesSchema& files) { files.clear(); try { server::MetricCollector metric; - //check table existence + // check table existence TableSchema table_schema; table_schema.table_id_ = table_id; auto status = DescribeTable(table_schema); @@ -1242,26 +1228,26 @@ MySQLMetaImpl::FilesToMerge(const std::string &table_id, mysqlpp::Query filesToMergeQuery = connectionPtr->query(); filesToMergeQuery << "SELECT id, table_id, file_id, file_type, file_size, row_count, date, engine_type, created_on FROM " - << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type = " << std::to_string(TableFileSchema::RAW) << " " << - "ORDER BY row_count DESC" << ";"; + << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type = " << std::to_string(TableFileSchema::RAW) << " " + << "ORDER BY row_count DESC" + << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::FilesToMerge: " << filesToMergeQuery.str(); res = filesToMergeQuery.store(); - } //Scoped Connection + } // Scoped Connection Status ret; - for (auto &resRow : res) { + for (auto& resRow : res) { TableFileSchema table_file; table_file.file_size_ = resRow["file_size"]; if (table_file.file_size_ >= table_schema.index_file_size_) { - continue;//skip large file + continue; // skip large file } - table_file.id_ = resRow["id"]; //implicit conversion + table_file.id_ = resRow["id"]; // implicit conversion std::string table_id_str; resRow["table_id"].to_string(table_id_str); @@ -1306,25 +1292,24 @@ MySQLMetaImpl::FilesToMerge(const std::string &table_id, ENGINE_LOG_DEBUG << "Collect " << res.size() << " to-merge files"; } return ret; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN FINDING TABLE FILES TO MERGE", e.what()); } } Status -MySQLMetaImpl::GetTableFiles(const std::string &table_id, - const std::vector &ids, - TableFilesSchema &table_files) { +MySQLMetaImpl::GetTableFiles(const std::string& table_id, const std::vector& ids, + TableFilesSchema& table_files) { if (ids.empty()) { return Status::OK(); } std::stringstream idSS; - for (auto &id : ids) { + for (auto& id : ids) { idSS << "id = " << std::to_string(id) << " OR "; } std::string idStr = idSS.str(); - idStr = idStr.substr(0, idStr.size() - 4); //remove the last " OR " + idStr = idStr.substr(0, idStr.size() - 4); // remove the last " OR " try { mysqlpp::StoreQueryResult res; @@ -1337,23 +1322,23 @@ MySQLMetaImpl::GetTableFiles(const std::string &table_id, mysqlpp::Query getTableFileQuery = connectionPtr->query(); getTableFileQuery - << "SELECT id, engine_type, file_id, file_type, file_size, row_count, date, created_on FROM " << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "(" << idStr << ") AND " << - "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; + << "SELECT id, engine_type, file_id, file_type, file_size, row_count, date, created_on FROM " + << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "(" << idStr << ") AND " + << "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::GetTableFiles: " << getTableFileQuery.str(); res = getTableFileQuery.store(); - } //Scoped Connection + } // Scoped Connection TableSchema table_schema; table_schema.table_id_ = table_id; DescribeTable(table_schema); Status ret; - for (auto &resRow : res) { + for (auto& resRow : res) { TableFileSchema file_schema; file_schema.id_ = resRow["id"]; @@ -1391,22 +1376,22 @@ MySQLMetaImpl::GetTableFiles(const std::string &table_id, ENGINE_LOG_DEBUG << "Get table files by id"; return ret; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN RETRIEVING TABLE FILES", e.what()); } } -// PXU TODO: Support Swap +// TODO(myh): Support swap to cloud storage Status MySQLMetaImpl::Archive() { - auto &criterias = options_.archive_conf_.GetCriterias(); + auto& criterias = options_.archive_conf_.GetCriterias(); if (criterias.empty()) { return Status::OK(); } - for (auto &kv : criterias) { - auto &criteria = kv.first; - auto &limit = kv.second; + for (auto& kv : criterias) { + auto& criteria = kv.first; + auto& limit = kv.second; if (criteria == engine::ARCHIVE_CONF_DAYS) { size_t usecs = limit * D_SEC * US_PS; int64_t now = utils::GetMicroSecTimeStamp(); @@ -1419,11 +1404,10 @@ MySQLMetaImpl::Archive() { } mysqlpp::Query archiveQuery = connectionPtr->query(); - archiveQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << " " << - "WHERE created_on < " << std::to_string(now - usecs) << " AND " << - "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; + archiveQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << " " + << "WHERE created_on < " << std::to_string(now - usecs) << " AND " + << "file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::Archive: " << archiveQuery.str(); @@ -1432,7 +1416,7 @@ MySQLMetaImpl::Archive() { } ENGINE_LOG_DEBUG << "Archive old files"; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DURING ARCHIVE", e.what()); } } @@ -1451,7 +1435,7 @@ MySQLMetaImpl::Archive() { } Status -MySQLMetaImpl::Size(uint64_t &result) { +MySQLMetaImpl::Size(uint64_t& result) { result = 0; try { @@ -1464,21 +1448,20 @@ MySQLMetaImpl::Size(uint64_t &result) { } mysqlpp::Query getSizeQuery = connectionPtr->query(); - getSizeQuery << "SELECT IFNULL(SUM(file_size),0) AS sum FROM " << - META_TABLEFILES << " " << - "WHERE file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; + getSizeQuery << "SELECT IFNULL(SUM(file_size),0) AS sum FROM " << META_TABLEFILES << " " + << "WHERE file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::Size: " << getSizeQuery.str(); res = getSizeQuery.store(); - } //Scoped Connection + } // Scoped Connection if (res.empty()) { result = 0; } else { result = res[0]["sum"]; } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN RETRIEVING SIZE", e.what()); } @@ -1503,11 +1486,10 @@ MySQLMetaImpl::DiscardFiles(int64_t to_discard_size) { } mysqlpp::Query discardFilesQuery = connectionPtr->query(); - discardFilesQuery << "SELECT id, file_size FROM " << - META_TABLEFILES << " " << - "WHERE file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << " " << - "ORDER BY id ASC " << - "LIMIT 10;"; + discardFilesQuery << "SELECT id, file_size FROM " << META_TABLEFILES << " " + << "WHERE file_type <> " << std::to_string(TableFileSchema::TO_DELETE) << " " + << "ORDER BY id ASC " + << "LIMIT 10;"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DiscardFiles: " << discardFilesQuery.str(); @@ -1518,7 +1500,7 @@ MySQLMetaImpl::DiscardFiles(int64_t to_discard_size) { TableFileSchema table_file; std::stringstream idsToDiscardSS; - for (auto &resRow : res) { + for (auto& resRow : res) { if (to_discard_size <= 0) { break; } @@ -1531,13 +1513,12 @@ MySQLMetaImpl::DiscardFiles(int64_t to_discard_size) { } std::string idsToDiscardStr = idsToDiscardSS.str(); - idsToDiscardStr = idsToDiscardStr.substr(0, idsToDiscardStr.size() - 4); //remove the last " OR " + idsToDiscardStr = idsToDiscardStr.substr(0, idsToDiscardStr.size() - 4); // remove the last " OR " - discardFilesQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << ", " << - "updated_time = " << std::to_string(utils::GetMicroSecTimeStamp()) << " " << - "WHERE " << idsToDiscardStr << ";"; + discardFilesQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_DELETE) << ", " + << "updated_time = " << std::to_string(utils::GetMicroSecTimeStamp()) << " " + << "WHERE " << idsToDiscardStr << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::DiscardFiles: " << discardFilesQuery.str(); @@ -1545,17 +1526,17 @@ MySQLMetaImpl::DiscardFiles(int64_t to_discard_size) { if (!status) { return HandleException("QUERY ERROR WHEN DISCARDING FILES", discardFilesQuery.error()); } - } //Scoped Connection + } // Scoped Connection return DiscardFiles(to_discard_size); - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DISCARDING FILES", e.what()); } } -//ZR: this function assumes all fields in file_schema have value +// ZR: this function assumes all fields in file_schema have value Status -MySQLMetaImpl::UpdateTableFile(TableFileSchema &file_schema) { +MySQLMetaImpl::UpdateTableFile(TableFileSchema& file_schema) { file_schema.updated_time_ = utils::GetMicroSecTimeStamp(); try { @@ -1569,11 +1550,10 @@ MySQLMetaImpl::UpdateTableFile(TableFileSchema &file_schema) { mysqlpp::Query updateTableFileQuery = connectionPtr->query(); - //if the table has been deleted, just mark the table file as TO_DELETE - //clean thread will delete the file later - updateTableFileQuery << "SELECT state FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << file_schema.table_id_ << ";"; + // if the table has been deleted, just mark the table file as TO_DELETE + // clean thread will delete the file later + updateTableFileQuery << "SELECT state FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << file_schema.table_id_ << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFile: " << updateTableFileQuery.str(); @@ -1599,18 +1579,17 @@ MySQLMetaImpl::UpdateTableFile(TableFileSchema &file_schema) { std::string created_on = std::to_string(file_schema.created_on_); std::string date = std::to_string(file_schema.date_); - updateTableFileQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET table_id = " << mysqlpp::quote << table_id << ", " << - "engine_type = " << engine_type << ", " << - "file_id = " << mysqlpp::quote << file_id << ", " << - "file_type = " << file_type << ", " << - "file_size = " << file_size << ", " << - "row_count = " << row_count << ", " << - "updated_time = " << updated_time << ", " << - "created_on = " << created_on << ", " << - "date = " << date << " " << - "WHERE id = " << id << ";"; + updateTableFileQuery << "UPDATE " << META_TABLEFILES << " " + << "SET table_id = " << mysqlpp::quote << table_id << ", " + << "engine_type = " << engine_type << ", " + << "file_id = " << mysqlpp::quote << file_id << ", " + << "file_type = " << file_type << ", " + << "file_size = " << file_size << ", " + << "row_count = " << row_count << ", " + << "updated_time = " << updated_time << ", " + << "created_on = " << created_on << ", " + << "date = " << date << " " + << "WHERE id = " << id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFile: " << updateTableFileQuery.str(); @@ -1618,10 +1597,10 @@ MySQLMetaImpl::UpdateTableFile(TableFileSchema &file_schema) { ENGINE_LOG_DEBUG << "table_id= " << file_schema.table_id_ << " file_id=" << file_schema.file_id_; return HandleException("QUERY ERROR WHEN UPDATING TABLE FILE", updateTableFileQuery.error()); } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Update single table file, file id = " << file_schema.file_id_; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE FILE", e.what()); } @@ -1629,7 +1608,7 @@ MySQLMetaImpl::UpdateTableFile(TableFileSchema &file_schema) { } Status -MySQLMetaImpl::UpdateTableFilesToIndex(const std::string &table_id) { +MySQLMetaImpl::UpdateTableFilesToIndex(const std::string& table_id) { try { mysqlpp::ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab_); @@ -1639,11 +1618,10 @@ MySQLMetaImpl::UpdateTableFilesToIndex(const std::string &table_id) { mysqlpp::Query updateTableFilesToIndexQuery = connectionPtr->query(); - updateTableFilesToIndexQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "file_type = " << std::to_string(TableFileSchema::RAW) << ";"; + updateTableFilesToIndexQuery << "UPDATE " << META_TABLEFILES << " " + << "SET file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "file_type = " << std::to_string(TableFileSchema::RAW) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFilesToIndex: " << updateTableFilesToIndexQuery.str(); @@ -1653,7 +1631,7 @@ MySQLMetaImpl::UpdateTableFilesToIndex(const std::string &table_id) { } ENGINE_LOG_DEBUG << "Update files to to_index, table id = " << table_id; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE FILES TO INDEX", e.what()); } @@ -1661,7 +1639,7 @@ MySQLMetaImpl::UpdateTableFilesToIndex(const std::string &table_id) { } Status -MySQLMetaImpl::UpdateTableFiles(TableFilesSchema &files) { +MySQLMetaImpl::UpdateTableFiles(TableFilesSchema& files) { try { server::MetricCollector metric; { @@ -1674,17 +1652,17 @@ MySQLMetaImpl::UpdateTableFiles(TableFilesSchema &files) { mysqlpp::Query updateTableFilesQuery = connectionPtr->query(); std::map has_tables; - for (auto &file_schema : files) { + for (auto& file_schema : files) { if (has_tables.find(file_schema.table_id_) != has_tables.end()) { continue; } - updateTableFilesQuery << "SELECT EXISTS " << - "(SELECT 1 FROM " << - META_TABLES << " " << - "WHERE table_id = " << mysqlpp::quote << file_schema.table_id_ << " " << - "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ") " << - "AS " << mysqlpp::quote << "check" << ";"; + updateTableFilesQuery << "SELECT EXISTS " + << "(SELECT 1 FROM " << META_TABLES << " " + << "WHERE table_id = " << mysqlpp::quote << file_schema.table_id_ << " " + << "AND state <> " << std::to_string(TableSchema::TO_DELETE) << ") " + << "AS " << mysqlpp::quote << "check" + << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFiles: " << updateTableFilesQuery.str(); @@ -1694,7 +1672,7 @@ MySQLMetaImpl::UpdateTableFiles(TableFilesSchema &files) { has_tables[file_schema.table_id_] = (check == 1); } - for (auto &file_schema : files) { + for (auto& file_schema : files) { if (!has_tables[file_schema.table_id_]) { file_schema.file_type_ = TableFileSchema::TO_DELETE; } @@ -1711,18 +1689,17 @@ MySQLMetaImpl::UpdateTableFiles(TableFilesSchema &files) { std::string created_on = std::to_string(file_schema.created_on_); std::string date = std::to_string(file_schema.date_); - updateTableFilesQuery << "UPDATE " << - META_TABLEFILES << " " << - "SET table_id = " << mysqlpp::quote << table_id << ", " << - "engine_type = " << engine_type << ", " << - "file_id = " << mysqlpp::quote << file_id << ", " << - "file_type = " << file_type << ", " << - "file_size = " << file_size << ", " << - "row_count = " << row_count << ", " << - "updated_time = " << updated_time << ", " << - "created_on = " << created_on << ", " << - "date = " << date << " " << - "WHERE id = " << id << ";"; + updateTableFilesQuery << "UPDATE " << META_TABLEFILES << " " + << "SET table_id = " << mysqlpp::quote << table_id << ", " + << "engine_type = " << engine_type << ", " + << "file_id = " << mysqlpp::quote << file_id << ", " + << "file_type = " << file_type << ", " + << "file_size = " << file_size << ", " + << "row_count = " << row_count << ", " + << "updated_time = " << updated_time << ", " + << "created_on = " << created_on << ", " + << "date = " << date << " " + << "WHERE id = " << id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::UpdateTableFiles: " << updateTableFilesQuery.str(); @@ -1730,10 +1707,10 @@ MySQLMetaImpl::UpdateTableFiles(TableFilesSchema &files) { return HandleException("QUERY ERROR WHEN UPDATING TABLE FILES", updateTableFilesQuery.error()); } } - } //Scoped Connection + } // Scoped Connection ENGINE_LOG_DEBUG << "Update " << files.size() << " table files"; - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN UPDATING TABLE FILES", e.what()); } @@ -1745,7 +1722,7 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { auto now = utils::GetMicroSecTimeStamp(); std::set table_ids; - //remove to_delete files + // remove to_delete files try { server::MetricCollector metric; @@ -1757,10 +1734,9 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { } mysqlpp::Query cleanUpFilesWithTTLQuery = connectionPtr->query(); - cleanUpFilesWithTTLQuery << "SELECT id, table_id, file_id, date FROM " << - META_TABLEFILES << " " << - "WHERE file_type = " << std::to_string(TableFileSchema::TO_DELETE) << " AND " << - "updated_time < " << std::to_string(now - seconds * US_PS) << ";"; + cleanUpFilesWithTTLQuery << "SELECT id, table_id, file_id, date FROM " << META_TABLEFILES << " " + << "WHERE file_type = " << std::to_string(TableFileSchema::TO_DELETE) << " AND " + << "updated_time < " << std::to_string(now - seconds * US_PS) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUpFilesWithTTL: " << cleanUpFilesWithTTLQuery.str(); @@ -1769,8 +1745,8 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { TableFileSchema table_file; std::vector idsToDelete; - for (auto &resRow : res) { - table_file.id_ = resRow["id"]; //implicit conversion + for (auto& resRow : res) { + table_file.id_ = resRow["id"]; // implicit conversion std::string table_id; resRow["table_id"].to_string(table_id); @@ -1793,15 +1769,14 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { if (!idsToDelete.empty()) { std::stringstream idsToDeleteSS; - for (auto &id : idsToDelete) { + for (auto& id : idsToDelete) { idsToDeleteSS << "id = " << id << " OR "; } std::string idsToDeleteStr = idsToDeleteSS.str(); - idsToDeleteStr = idsToDeleteStr.substr(0, idsToDeleteStr.size() - 4); //remove the last " OR " - cleanUpFilesWithTTLQuery << "DELETE FROM " << - META_TABLEFILES << " " << - "WHERE " << idsToDeleteStr << ";"; + idsToDeleteStr = idsToDeleteStr.substr(0, idsToDeleteStr.size() - 4); // remove the last " OR " + cleanUpFilesWithTTLQuery << "DELETE FROM " << META_TABLEFILES << " " + << "WHERE " << idsToDeleteStr << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUpFilesWithTTL: " << cleanUpFilesWithTTLQuery.str(); @@ -1814,12 +1789,12 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { if (res.size() > 0) { ENGINE_LOG_DEBUG << "Clean " << res.size() << " files deleted in " << seconds << " seconds"; } - } //Scoped Connection - } catch (std::exception &e) { + } // Scoped Connection + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CLEANING UP FILES WITH TTL", e.what()); } - //remove to_delete tables + // remove to_delete tables try { server::MetricCollector metric; @@ -1831,9 +1806,8 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { } mysqlpp::Query cleanUpFilesWithTTLQuery = connectionPtr->query(); - cleanUpFilesWithTTLQuery << "SELECT id, table_id FROM " << - META_TABLES << " " << - "WHERE state = " << std::to_string(TableSchema::TO_DELETE) << ";"; + cleanUpFilesWithTTLQuery << "SELECT id, table_id FROM " << META_TABLES << " " + << "WHERE state = " << std::to_string(TableSchema::TO_DELETE) << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUpFilesWithTTL: " << cleanUpFilesWithTTLQuery.str(); @@ -1841,20 +1815,19 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { if (!res.empty()) { std::stringstream idsToDeleteSS; - for (auto &resRow : res) { + for (auto& resRow : res) { size_t id = resRow["id"]; std::string table_id; resRow["table_id"].to_string(table_id); - utils::DeleteTablePath(options_, table_id, false);//only delete empty folder + utils::DeleteTablePath(options_, table_id, false); // only delete empty folder idsToDeleteSS << "id = " << std::to_string(id) << " OR "; } std::string idsToDeleteStr = idsToDeleteSS.str(); - idsToDeleteStr = idsToDeleteStr.substr(0, idsToDeleteStr.size() - 4); //remove the last " OR " - cleanUpFilesWithTTLQuery << "DELETE FROM " << - META_TABLES << " " << - "WHERE " << idsToDeleteStr << ";"; + idsToDeleteStr = idsToDeleteStr.substr(0, idsToDeleteStr.size() - 4); // remove the last " OR " + cleanUpFilesWithTTLQuery << "DELETE FROM " << META_TABLES << " " + << "WHERE " << idsToDeleteStr << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUpFilesWithTTL: " << cleanUpFilesWithTTLQuery.str(); @@ -1867,13 +1840,13 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { if (res.size() > 0) { ENGINE_LOG_DEBUG << "Remove " << res.size() << " tables from meta"; } - } //Scoped Connection - } catch (std::exception &e) { + } // Scoped Connection + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CLEANING UP TABLES WITH TTL", e.what()); } - //remove deleted table folder - //don't remove table folder until all its files has been deleted + // remove deleted table folder + // don't remove table folder until all its files has been deleted try { server::MetricCollector metric; @@ -1884,11 +1857,10 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { return Status(DB_ERROR, "Failed to connect to database server"); } - for (auto &table_id : table_ids) { + for (auto& table_id : table_ids) { mysqlpp::Query cleanUpFilesWithTTLQuery = connectionPtr->query(); - cleanUpFilesWithTTLQuery << "SELECT file_id FROM " << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << ";"; + cleanUpFilesWithTTLQuery << "SELECT file_id FROM " << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUpFilesWithTTL: " << cleanUpFilesWithTTLQuery.str(); @@ -1903,7 +1875,7 @@ MySQLMetaImpl::CleanUpFilesWithTTL(uint16_t seconds) { ENGINE_LOG_DEBUG << "Remove " << table_ids.size() << " tables folder"; } } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CLEANING UP TABLES WITH TTL", e.what()); } @@ -1920,10 +1892,10 @@ MySQLMetaImpl::CleanUp() { } mysqlpp::Query cleanUpQuery = connectionPtr->query(); - cleanUpQuery << "SELECT table_name " << - "FROM information_schema.tables " << - "WHERE table_schema = " << mysqlpp::quote << mysql_connection_pool_->getDB() << " " << - "AND table_name = " << mysqlpp::quote << META_TABLEFILES << ";"; + cleanUpQuery << "SELECT table_name " + << "FROM information_schema.tables " + << "WHERE table_schema = " << mysqlpp::quote << mysql_connection_pool_->getDB() << " " + << "AND table_name = " << mysqlpp::quote << META_TABLEFILES << ";"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUp: " << cleanUpQuery.str(); @@ -1932,9 +1904,8 @@ MySQLMetaImpl::CleanUp() { if (!res.empty()) { ENGINE_LOG_DEBUG << "Remove table file type as NEW"; cleanUpQuery << "DELETE FROM " << META_TABLEFILES << " WHERE file_type IN (" - << std::to_string(TableFileSchema::NEW) << "," - << std::to_string(TableFileSchema::NEW_MERGE) << "," - << std::to_string(TableFileSchema::NEW_INDEX) << ");"; + << std::to_string(TableFileSchema::NEW) << "," << std::to_string(TableFileSchema::NEW_MERGE) + << "," << std::to_string(TableFileSchema::NEW_INDEX) << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CleanUp: " << cleanUpQuery.str(); @@ -1946,7 +1917,7 @@ MySQLMetaImpl::CleanUp() { if (res.size() > 0) { ENGINE_LOG_DEBUG << "Clean " << res.size() << " files"; } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN CLEANING UP FILES", e.what()); } @@ -1954,7 +1925,7 @@ MySQLMetaImpl::CleanUp() { } Status -MySQLMetaImpl::Count(const std::string &table_id, uint64_t &result) { +MySQLMetaImpl::Count(const std::string& table_id, uint64_t& result) { try { server::MetricCollector metric; @@ -1975,24 +1946,23 @@ MySQLMetaImpl::Count(const std::string &table_id, uint64_t &result) { } mysqlpp::Query countQuery = connectionPtr->query(); - countQuery << "SELECT row_count FROM " << - META_TABLEFILES << " " << - "WHERE table_id = " << mysqlpp::quote << table_id << " AND " << - "(file_type = " << std::to_string(TableFileSchema::RAW) << " OR " << - "file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " OR " << - "file_type = " << std::to_string(TableFileSchema::INDEX) << ");"; + countQuery << "SELECT row_count FROM " << META_TABLEFILES << " " + << "WHERE table_id = " << mysqlpp::quote << table_id << " AND " + << "(file_type = " << std::to_string(TableFileSchema::RAW) << " OR " + << "file_type = " << std::to_string(TableFileSchema::TO_INDEX) << " OR " + << "file_type = " << std::to_string(TableFileSchema::INDEX) << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::Count: " << countQuery.str(); res = countQuery.store(); - } //Scoped Connection + } // Scoped Connection result = 0; - for (auto &resRow : res) { + for (auto& resRow : res) { size_t size = resRow["row_count"]; result += size; } - } catch (std::exception &e) { + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN RETRIEVING COUNT", e.what()); } @@ -2016,15 +1986,13 @@ MySQLMetaImpl::DropAll() { if (dropTableQuery.exec()) { return Status::OK(); - } else { - return HandleException("QUERY ERROR WHEN DROPPING ALL", dropTableQuery.error()); } - } catch (std::exception &e) { + return HandleException("QUERY ERROR WHEN DROPPING ALL", dropTableQuery.error()); + } catch (std::exception& e) { return HandleException("GENERAL ERROR WHEN DROPPING ALL", e.what()); } } -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/MySQLMetaImpl.h b/cpp/src/db/meta/MySQLMetaImpl.h index d9d67fd748d350f9ed8645d4d137c8ed1dd3a2f7..7ca66bc992adb80f10fee6a4d2c92a2571e83588 100644 --- a/cpp/src/db/meta/MySQLMetaImpl.h +++ b/cpp/src/db/meta/MySQLMetaImpl.h @@ -18,93 +18,116 @@ #pragma once #include "Meta.h" -#include "db/Options.h" #include "MySQLConnectionPool.h" +#include "db/Options.h" #include +#include #include -#include #include -#include +#include -namespace zilliz { namespace milvus { namespace engine { namespace meta { class MySQLMetaImpl : public Meta { public: - MySQLMetaImpl(const DBMetaOptions &options, const int &mode); + MySQLMetaImpl(const DBMetaOptions& options, const int& mode); ~MySQLMetaImpl(); - Status CreateTable(TableSchema &table_schema) override; + Status + CreateTable(TableSchema& table_schema) override; - Status DescribeTable(TableSchema &table_schema) override; + Status + DescribeTable(TableSchema& table_schema) override; - Status HasTable(const std::string &table_id, bool &has_or_not) override; + Status + HasTable(const std::string& table_id, bool& has_or_not) override; - Status AllTables(std::vector &table_schema_array) override; + Status + AllTables(std::vector& table_schema_array) override; - Status DeleteTable(const std::string &table_id) override; + Status + DeleteTable(const std::string& table_id) override; - Status DeleteTableFiles(const std::string &table_id) override; + Status + DeleteTableFiles(const std::string& table_id) override; - Status CreateTableFile(TableFileSchema &file_schema) override; + Status + CreateTableFile(TableFileSchema& file_schema) override; - Status DropPartitionsByDates(const std::string &table_id, - const DatesT &dates) override; + Status + DropPartitionsByDates(const std::string& table_id, const DatesT& dates) override; - Status GetTableFiles(const std::string &table_id, - const std::vector &ids, - TableFilesSchema &table_files) override; + Status + GetTableFiles(const std::string& table_id, const std::vector& ids, TableFilesSchema& table_files) override; - Status FilesByType(const std::string &table_id, - const std::vector &file_types, - std::vector &file_ids) override; + Status + FilesByType(const std::string& table_id, const std::vector& file_types, + std::vector& file_ids) override; - Status UpdateTableIndex(const std::string &table_id, const TableIndex &index) override; + Status + UpdateTableIndex(const std::string& table_id, const TableIndex& index) override; - Status UpdateTableFlag(const std::string &table_id, int64_t flag) override; + Status + UpdateTableFlag(const std::string& table_id, int64_t flag) override; - Status DescribeTableIndex(const std::string &table_id, TableIndex &index) override; + Status + DescribeTableIndex(const std::string& table_id, TableIndex& index) override; - Status DropTableIndex(const std::string &table_id) override; + Status + DropTableIndex(const std::string& table_id) override; - Status UpdateTableFile(TableFileSchema &file_schema) override; + Status + UpdateTableFile(TableFileSchema& file_schema) override; - Status UpdateTableFilesToIndex(const std::string &table_id) override; + Status + UpdateTableFilesToIndex(const std::string& table_id) override; - Status UpdateTableFiles(TableFilesSchema &files) override; + Status + UpdateTableFiles(TableFilesSchema& files) override; - Status FilesToSearch(const std::string &table_id, - const std::vector &ids, - const DatesT &partition, - DatePartionedTableFilesSchema &files) override; + Status + FilesToSearch(const std::string& table_id, const std::vector& ids, const DatesT& partition, + DatePartionedTableFilesSchema& files) override; - Status FilesToMerge(const std::string &table_id, - DatePartionedTableFilesSchema &files) override; + Status + FilesToMerge(const std::string& table_id, DatePartionedTableFilesSchema& files) override; - Status FilesToIndex(TableFilesSchema &) override; + Status + FilesToIndex(TableFilesSchema&) override; - Status Archive() override; + Status + Archive() override; - Status Size(uint64_t &result) override; + Status + Size(uint64_t& result) override; - Status CleanUp() override; + Status + CleanUp() override; - Status CleanUpFilesWithTTL(uint16_t seconds) override; + Status + CleanUpFilesWithTTL(uint16_t seconds) override; - Status DropAll() override; + Status + DropAll() override; - Status Count(const std::string &table_id, uint64_t &result) override; + Status + Count(const std::string& table_id, uint64_t& result) override; private: - Status NextFileId(std::string &file_id); - Status NextTableId(std::string &table_id); - Status DiscardFiles(int64_t to_discard_size); - - void ValidateMetaSchema(); - Status Initialize(); + Status + NextFileId(std::string& file_id); + Status + NextTableId(std::string& table_id); + Status + DiscardFiles(int64_t to_discard_size); + + void + ValidateMetaSchema(); + Status + Initialize(); private: const DBMetaOptions options_; @@ -113,10 +136,9 @@ class MySQLMetaImpl : public Meta { std::shared_ptr mysql_connection_pool_; bool safe_grab_ = false; -// std::mutex connectionMutex_; -}; // DBMetaImpl + // std::mutex connectionMutex_; +}; // DBMetaImpl -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/db/meta/SqliteMetaImpl.cpp b/cpp/src/db/meta/SqliteMetaImpl.cpp index 471b182981df2d5a183f78b6f1d72064088e4067..dd9bb6fd300bff7dc15470c61ad92c92f0a1b2e8 100644 --- a/cpp/src/db/meta/SqliteMetaImpl.cpp +++ b/cpp/src/db/meta/SqliteMetaImpl.cpp @@ -34,7 +34,7 @@ #include #include -namespace zilliz { + namespace milvus { namespace engine { namespace meta { @@ -155,7 +155,7 @@ SqliteMetaImpl::Initialize() { return Status::OK(); } -// PXU TODO: Temp solution. Will fix later +// TODO(myh): Delete single vecotor by id Status SqliteMetaImpl::DropPartitionsByDates(const std::string &table_id, const DatesT &dates) { @@ -885,7 +885,7 @@ SqliteMetaImpl::GetTableFiles(const std::string &table_id, } } -// PXU TODO: Support Swap +// TODO(myh): Support swap to cloud storage Status SqliteMetaImpl::Archive() { auto &criterias = options_.archive_conf_.GetCriterias(); @@ -1298,4 +1298,4 @@ SqliteMetaImpl::DropAll() { } // namespace meta } // namespace engine } // namespace milvus -} // namespace zilliz + diff --git a/cpp/src/db/meta/SqliteMetaImpl.h b/cpp/src/db/meta/SqliteMetaImpl.h index 978e5e5a4baf974c49e2d2cc1935445742c65bfe..dc132c41ec5dd3b9e675898b0deaae64a227c5cf 100644 --- a/cpp/src/db/meta/SqliteMetaImpl.h +++ b/cpp/src/db/meta/SqliteMetaImpl.h @@ -21,95 +21,119 @@ #include "db/Options.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { namespace meta { auto -StoragePrototype(const std::string &path); +StoragePrototype(const std::string& path); class SqliteMetaImpl : public Meta { public: - explicit SqliteMetaImpl(const DBMetaOptions &options); + explicit SqliteMetaImpl(const DBMetaOptions& options); ~SqliteMetaImpl(); - Status CreateTable(TableSchema &table_schema) override; + Status + CreateTable(TableSchema& table_schema) override; - Status DescribeTable(TableSchema &table_schema) override; + Status + DescribeTable(TableSchema& table_schema) override; - Status HasTable(const std::string &table_id, bool &has_or_not) override; + Status + HasTable(const std::string& table_id, bool& has_or_not) override; - Status AllTables(std::vector &table_schema_array) override; + Status + AllTables(std::vector& table_schema_array) override; - Status DeleteTable(const std::string &table_id) override; + Status + DeleteTable(const std::string& table_id) override; - Status DeleteTableFiles(const std::string &table_id) override; + Status + DeleteTableFiles(const std::string& table_id) override; - Status CreateTableFile(TableFileSchema &file_schema) override; + Status + CreateTableFile(TableFileSchema& file_schema) override; - Status DropPartitionsByDates(const std::string &table_id, const DatesT &dates) override; + Status + DropPartitionsByDates(const std::string& table_id, const DatesT& dates) override; - Status GetTableFiles(const std::string &table_id, - const std::vector &ids, - TableFilesSchema &table_files) override; + Status + GetTableFiles(const std::string& table_id, const std::vector& ids, TableFilesSchema& table_files) override; - Status FilesByType(const std::string &table_id, - const std::vector &file_types, - std::vector &file_ids) override; + Status + FilesByType(const std::string& table_id, const std::vector& file_types, + std::vector& file_ids) override; - Status UpdateTableIndex(const std::string &table_id, const TableIndex &index) override; + Status + UpdateTableIndex(const std::string& table_id, const TableIndex& index) override; - Status UpdateTableFlag(const std::string &table_id, int64_t flag) override; + Status + UpdateTableFlag(const std::string& table_id, int64_t flag) override; - Status DescribeTableIndex(const std::string &table_id, TableIndex &index) override; + Status + DescribeTableIndex(const std::string& table_id, TableIndex& index) override; - Status DropTableIndex(const std::string &table_id) override; + Status + DropTableIndex(const std::string& table_id) override; - Status UpdateTableFilesToIndex(const std::string &table_id) override; + Status + UpdateTableFilesToIndex(const std::string& table_id) override; - Status UpdateTableFile(TableFileSchema &file_schema) override; + Status + UpdateTableFile(TableFileSchema& file_schema) override; - Status UpdateTableFiles(TableFilesSchema &files) override; + Status + UpdateTableFiles(TableFilesSchema& files) override; - Status FilesToSearch(const std::string &table_id, - const std::vector &ids, - const DatesT &partition, - DatePartionedTableFilesSchema &files) override; + Status + FilesToSearch(const std::string& table_id, const std::vector& ids, const DatesT& partition, + DatePartionedTableFilesSchema& files) override; - Status FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) override; + Status + FilesToMerge(const std::string& table_id, DatePartionedTableFilesSchema& files) override; - Status FilesToIndex(TableFilesSchema &) override; + Status + FilesToIndex(TableFilesSchema&) override; - Status Archive() override; + Status + Archive() override; - Status Size(uint64_t &result) override; + Status + Size(uint64_t& result) override; - Status CleanUp() override; + Status + CleanUp() override; - Status CleanUpFilesWithTTL(uint16_t seconds) override; + Status + CleanUpFilesWithTTL(uint16_t seconds) override; - Status DropAll() override; + Status + DropAll() override; - Status Count(const std::string &table_id, uint64_t &result) override; + Status + Count(const std::string& table_id, uint64_t& result) override; private: - Status NextFileId(std::string &file_id); - Status NextTableId(std::string &table_id); - Status DiscardFiles(int64_t to_discard_size); - - void ValidateMetaSchema(); - Status Initialize(); + Status + NextFileId(std::string& file_id); + Status + NextTableId(std::string& table_id); + Status + DiscardFiles(int64_t to_discard_size); + + void + ValidateMetaSchema(); + Status + Initialize(); private: const DBMetaOptions options_; std::mutex meta_mutex_; -}; // DBMetaImpl +}; // DBMetaImpl -} // namespace meta -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace meta +} // namespace engine +} // namespace milvus diff --git a/cpp/src/grpc/cpp_gen.sh b/cpp/src/grpc/cpp_gen.sh old mode 100755 new mode 100644 diff --git a/cpp/src/main.cpp b/cpp/src/main.cpp index 22c282d4ea0f4150dfa406a3ae37a0730844c823..b50eedeabaec2edd6dedeffb805bb07d1628c95b 100644 --- a/cpp/src/main.cpp +++ b/cpp/src/main.cpp @@ -17,34 +17,34 @@ #include #include -#include -#include #include #include +#include +#include -#include "utils/easylogging++.h" -#include "utils/SignalUtil.h" -#include "utils/CommonUtil.h" +#include "../version.h" #include "metrics/Metrics.h" #include "server/Server.h" -#include "../version.h" - +#include "utils/CommonUtil.h" +#include "utils/SignalUtil.h" +#include "utils/easylogging++.h" INITIALIZE_EASYLOGGINGPP -void print_help(const std::string &app_name); +void +print_help(const std::string& app_name); int -main(int argc, char *argv[]) { +main(int argc, char* argv[]) { std::cout << std::endl << "Welcome to use Milvus by Zilliz!" << std::endl; std::cout << "Milvus " << BUILD_TYPE << " version: v" << MILVUS_VERSION << " built at " << BUILD_TIME << std::endl; - static struct option long_options[] = {{"conf_file", required_argument, 0, 'c'}, - {"log_conf_file", required_argument, 0, 'l'}, - {"help", no_argument, 0, 'h'}, - {"daemon", no_argument, 0, 'd'}, - {"pid_file", required_argument, 0, 'p'}, - {NULL, 0, 0, 0}}; + static struct option long_options[] = {{"conf_file", required_argument, nullptr, 'c'}, + {"log_conf_file", required_argument, nullptr, 'l'}, + {"help", no_argument, nullptr, 'h'}, + {"daemon", no_argument, nullptr, 'd'}, + {"pid_file", required_argument, nullptr, 'p'}, + {nullptr, 0, nullptr, 0}}; int option_index = 0; int64_t start_daemonized = 0; @@ -63,21 +63,21 @@ main(int argc, char *argv[]) { while ((value = getopt_long(argc, argv, "c:l:p:dh", long_options, &option_index)) != -1) { switch (value) { case 'c': { - char *config_filename_ptr = strdup(optarg); + char* config_filename_ptr = strdup(optarg); config_filename = config_filename_ptr; free(config_filename_ptr); std::cout << "Loading configuration from: " << config_filename << std::endl; break; } case 'l': { - char *log_filename_ptr = strdup(optarg); + char* log_filename_ptr = strdup(optarg); log_config_file = log_filename_ptr; free(log_filename_ptr); std::cout << "Initial log config from: " << log_config_file << std::endl; break; } case 'p': { - char *pid_filename_ptr = strdup(optarg); + char* pid_filename_ptr = strdup(optarg); pid_filename = pid_filename_ptr; free(pid_filename_ptr); std::cout << pid_filename << std::endl; @@ -99,14 +99,14 @@ main(int argc, char *argv[]) { } /* Handle Signal */ - signal(SIGHUP, zilliz::milvus::server::SignalUtil::HandleSignal); - signal(SIGINT, zilliz::milvus::server::SignalUtil::HandleSignal); - signal(SIGUSR1, zilliz::milvus::server::SignalUtil::HandleSignal); - signal(SIGSEGV, zilliz::milvus::server::SignalUtil::HandleSignal); - signal(SIGUSR2, zilliz::milvus::server::SignalUtil::HandleSignal); - signal(SIGTERM, zilliz::milvus::server::SignalUtil::HandleSignal); - - zilliz::milvus::server::Server &server = zilliz::milvus::server::Server::GetInstance(); + signal(SIGHUP, milvus::server::SignalUtil::HandleSignal); + signal(SIGINT, milvus::server::SignalUtil::HandleSignal); + signal(SIGUSR1, milvus::server::SignalUtil::HandleSignal); + signal(SIGSEGV, milvus::server::SignalUtil::HandleSignal); + signal(SIGUSR2, milvus::server::SignalUtil::HandleSignal); + signal(SIGTERM, milvus::server::SignalUtil::HandleSignal); + + milvus::server::Server& server = milvus::server::Server::GetInstance(); server.Init(start_daemonized, pid_filename, config_filename, log_config_file); server.Start(); @@ -117,7 +117,7 @@ main(int argc, char *argv[]) { } void -print_help(const std::string &app_name) { +print_help(const std::string& app_name) { std::cout << std::endl << "Usage: " << app_name << " [OPTIONS]" << std::endl << std::endl; std::cout << " Options:" << std::endl; std::cout << " -h --help Print this help" << std::endl; diff --git a/cpp/src/metrics/MetricBase.h b/cpp/src/metrics/MetricBase.h index 32a0e715e61f66c4e6711bcfcea877a2865129ff..eeca45e78923bb60678a0b464f78f0090422b82d 100644 --- a/cpp/src/metrics/MetricBase.h +++ b/cpp/src/metrics/MetricBase.h @@ -15,155 +15,195 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "utils/Error.h" #include "SystemInfo.h" +#include "utils/Error.h" #include -namespace zilliz { namespace milvus { namespace server { class MetricsBase { public: - static MetricsBase & + static MetricsBase& GetInstance() { static MetricsBase instance; return instance; } - virtual ErrorCode Init() { + virtual ErrorCode + Init() { } - virtual void AddVectorsSuccessTotalIncrement(double value = 1) { + virtual void + AddVectorsSuccessTotalIncrement(double value = 1) { } - virtual void AddVectorsFailTotalIncrement(double value = 1) { + virtual void + AddVectorsFailTotalIncrement(double value = 1) { } - virtual void AddVectorsDurationHistogramOberve(double value) { + virtual void + AddVectorsDurationHistogramOberve(double value) { } - virtual void RawFileSizeHistogramObserve(double value) { + virtual void + RawFileSizeHistogramObserve(double value) { } - virtual void IndexFileSizeHistogramObserve(double value) { + virtual void + IndexFileSizeHistogramObserve(double value) { } - virtual void BuildIndexDurationSecondsHistogramObserve(double value) { + virtual void + BuildIndexDurationSecondsHistogramObserve(double value) { } - virtual void CpuCacheUsageGaugeSet(double value) { + virtual void + CpuCacheUsageGaugeSet(double value) { } - virtual void GpuCacheUsageGaugeSet() { + virtual void + GpuCacheUsageGaugeSet() { } - virtual void MetaAccessTotalIncrement(double value = 1) { + virtual void + MetaAccessTotalIncrement(double value = 1) { } - virtual void MetaAccessDurationSecondsHistogramObserve(double value) { + virtual void + MetaAccessDurationSecondsHistogramObserve(double value) { } - virtual void FaissDiskLoadDurationSecondsHistogramObserve(double value) { + virtual void + FaissDiskLoadDurationSecondsHistogramObserve(double value) { } - virtual void FaissDiskLoadSizeBytesHistogramObserve(double value) { + virtual void + FaissDiskLoadSizeBytesHistogramObserve(double value) { } - virtual void CacheAccessTotalIncrement(double value = 1) { + virtual void + CacheAccessTotalIncrement(double value = 1) { } - virtual void MemTableMergeDurationSecondsHistogramObserve(double value) { + virtual void + MemTableMergeDurationSecondsHistogramObserve(double value) { } - virtual void SearchIndexDataDurationSecondsHistogramObserve(double value) { + virtual void + SearchIndexDataDurationSecondsHistogramObserve(double value) { } - virtual void SearchRawDataDurationSecondsHistogramObserve(double value) { + virtual void + SearchRawDataDurationSecondsHistogramObserve(double value) { } - virtual void IndexFileSizeTotalIncrement(double value = 1) { + virtual void + IndexFileSizeTotalIncrement(double value = 1) { } - virtual void RawFileSizeTotalIncrement(double value = 1) { + virtual void + RawFileSizeTotalIncrement(double value = 1) { } - virtual void IndexFileSizeGaugeSet(double value) { + virtual void + IndexFileSizeGaugeSet(double value) { } - virtual void RawFileSizeGaugeSet(double value) { + virtual void + RawFileSizeGaugeSet(double value) { } - virtual void FaissDiskLoadIOSpeedGaugeSet(double value) { + virtual void + FaissDiskLoadIOSpeedGaugeSet(double value) { } - virtual void QueryResponseSummaryObserve(double value) { + virtual void + QueryResponseSummaryObserve(double value) { } - virtual void DiskStoreIOSpeedGaugeSet(double value) { + virtual void + DiskStoreIOSpeedGaugeSet(double value) { } - virtual void DataFileSizeGaugeSet(double value) { + virtual void + DataFileSizeGaugeSet(double value) { } - virtual void AddVectorsSuccessGaugeSet(double value) { + virtual void + AddVectorsSuccessGaugeSet(double value) { } - virtual void AddVectorsFailGaugeSet(double value) { + virtual void + AddVectorsFailGaugeSet(double value) { } - virtual void QueryVectorResponseSummaryObserve(double value, int count = 1) { + virtual void + QueryVectorResponseSummaryObserve(double value, int count = 1) { } - virtual void QueryVectorResponsePerSecondGaugeSet(double value) { + virtual void + QueryVectorResponsePerSecondGaugeSet(double value) { } - virtual void CPUUsagePercentSet() { + virtual void + CPUUsagePercentSet() { } - virtual void RAMUsagePercentSet() { + virtual void + RAMUsagePercentSet() { } - virtual void QueryResponsePerSecondGaugeSet(double value) { + virtual void + QueryResponsePerSecondGaugeSet(double value) { } - virtual void GPUPercentGaugeSet() { + virtual void + GPUPercentGaugeSet() { } - virtual void GPUMemoryUsageGaugeSet() { + virtual void + GPUMemoryUsageGaugeSet() { } - virtual void AddVectorsPerSecondGaugeSet(int num_vector, int dim, double time) { + virtual void + AddVectorsPerSecondGaugeSet(int num_vector, int dim, double time) { } - virtual void QueryIndexTypePerSecondSet(std::string type, double value) { + virtual void + QueryIndexTypePerSecondSet(std::string type, double value) { } - virtual void ConnectionGaugeIncrement() { + virtual void + ConnectionGaugeIncrement() { } - virtual void ConnectionGaugeDecrement() { + virtual void + ConnectionGaugeDecrement() { } - virtual void KeepingAliveCounterIncrement(double value = 1) { + virtual void + KeepingAliveCounterIncrement(double value = 1) { } - virtual void OctetsSet() { + virtual void + OctetsSet() { } - virtual void CPUCoreUsagePercentSet() { + virtual void + CPUCoreUsagePercentSet() { } - virtual void GPUTemperature() { + virtual void + GPUTemperature() { } - virtual void CPUTemperature() { + virtual void + CPUTemperature() { } }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/Metrics.cpp b/cpp/src/metrics/Metrics.cpp index 0a1b333b4272d3f5d0bb9accee4527c905e11058..51db5555b8760ad945a2779fbd0c23e9ac1c935c 100644 --- a/cpp/src/metrics/Metrics.cpp +++ b/cpp/src/metrics/Metrics.cpp @@ -16,24 +16,23 @@ // under the License. #include "metrics/Metrics.h" -#include "server/Config.h" #include "PrometheusMetrics.h" +#include "server/Config.h" #include -namespace zilliz { namespace milvus { namespace server { -MetricsBase & +MetricsBase& Metrics::GetInstance() { - static MetricsBase &instance = CreateMetricsCollector(); + static MetricsBase& instance = CreateMetricsCollector(); return instance; } -MetricsBase & +MetricsBase& Metrics::CreateMetricsCollector() { - Config &config = Config::GetInstance(); + Config& config = Config::GetInstance(); std::string collector_type_str; config.GetMetricConfigCollector(collector_type_str); @@ -45,6 +44,5 @@ Metrics::CreateMetricsCollector() { } } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/Metrics.h b/cpp/src/metrics/Metrics.h index 66e06b8ba1d115a8c7151effbd48ef1489c407dc..c207a50d9e67ec398732b30cc36b784bc1188878 100644 --- a/cpp/src/metrics/Metrics.h +++ b/cpp/src/metrics/Metrics.h @@ -15,31 +15,27 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "MetricBase.h" #include "db/meta/MetaTypes.h" -namespace zilliz { namespace milvus { namespace server { #define METRICS_NOW_TIME std::chrono::system_clock::now() -#define METRICS_MICROSECONDS(a, b) (std::chrono::duration_cast (b-a)).count(); +#define METRICS_MICROSECONDS(a, b) (std::chrono::duration_cast(b - a)).count(); -enum class MetricCollectorType { - INVALID, - PROMETHEUS, - ZABBIX -}; +enum class MetricCollectorType { INVALID, PROMETHEUS, ZABBIX }; class Metrics { public: - static MetricsBase &GetInstance(); + static MetricsBase& + GetInstance(); private: - static MetricsBase &CreateMetricsCollector(); + static MetricsBase& + CreateMetricsCollector(); }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class CollectMetricsBase { @@ -50,7 +46,8 @@ class CollectMetricsBase { virtual ~CollectMetricsBase() = default; - double TimeFromBegine() { + double + TimeFromBegine() { auto end_time = METRICS_NOW_TIME; return METRICS_MICROSECONDS(start_time_, end_time); } @@ -63,7 +60,7 @@ class CollectMetricsBase { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class CollectInsertMetrics : CollectMetricsBase { public: - CollectInsertMetrics(size_t n, Status &status) : n_(n), status_(status) { + CollectInsertMetrics(size_t n, Status& status) : n_(n), status_(status) { } ~CollectInsertMetrics() { @@ -87,7 +84,7 @@ class CollectInsertMetrics : CollectMetricsBase { private: size_t n_; - Status &status_; + Status& status_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +159,7 @@ class CollectSerializeMetrics : CollectMetricsBase { ~CollectSerializeMetrics() { auto total_time = TimeFromBegine(); - server::Metrics::GetInstance().DiskStoreIOSpeedGaugeSet((double) size_ / total_time); + server::Metrics::GetInstance().DiskStoreIOSpeedGaugeSet((double)size_ / total_time); } private: @@ -177,8 +174,7 @@ class CollectAddMetrics : CollectMetricsBase { ~CollectAddMetrics() { auto total_time = TimeFromBegine(); - server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast(n_), - static_cast(dimension_), + server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast(n_), static_cast(dimension_), total_time); } @@ -256,6 +252,5 @@ class MetricCollector : CollectMetricsBase { } }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/PrometheusMetrics.cpp b/cpp/src/metrics/PrometheusMetrics.cpp index 8dcc97aa8b943b9b783d381ad3ec8cae629db9a5..182f14d46c6da125f840243ffd41a7d14c296f95 100644 --- a/cpp/src/metrics/PrometheusMetrics.cpp +++ b/cpp/src/metrics/PrometheusMetrics.cpp @@ -15,33 +15,38 @@ // specific language governing permissions and limitations // under the License. - #include "metrics/PrometheusMetrics.h" +#include "SystemInfo.h" #include "cache/GpuCacheMgr.h" #include "server/Config.h" #include "utils/Log.h" -#include "SystemInfo.h" #include #include -namespace zilliz { namespace milvus { namespace server { ErrorCode PrometheusMetrics::Init() { try { - Config &config = Config::GetInstance(); + Config& config = Config::GetInstance(); Status s = config.GetMetricConfigEnableMonitor(startup_); - if (!s.ok()) return s.code(); - if (!startup_) return SERVER_SUCCESS; + if (!s.ok()) { + return s.code(); + } + if (!startup_) { + return SERVER_SUCCESS; + } // Following should be read from config file. std::string bind_address; s = config.GetMetricConfigPrometheusPort(bind_address); - if (!s.ok()) return s.code(); - const std::string uri = std::string("/tmp/metrics"); + if (!s.ok()) { + return s.code(); + } + + const std::string uri = std::string("/metrics"); const std::size_t num_threads = 2; // Init Exposer @@ -49,7 +54,7 @@ PrometheusMetrics::Init() { // Exposer Registry exposer_ptr_->RegisterCollectable(registry_); - } catch (std::exception &ex) { + } catch (std::exception& ex) { SERVER_LOG_ERROR << "Failed to connect prometheus server: " << std::string(ex.what()); return SERVER_UNEXPECTED_ERROR; } @@ -59,41 +64,53 @@ PrometheusMetrics::Init() { void PrometheusMetrics::CPUUsagePercentSet() { - if (!startup_) return; + if (!startup_) { + return; + } + double usage_percent = server::SystemInfo::GetInstance().CPUPercent(); CPU_usage_percent_.Set(usage_percent); } void PrometheusMetrics::RAMUsagePercentSet() { - if (!startup_) return; + if (!startup_) { + return; + } + double usage_percent = server::SystemInfo::GetInstance().MemoryPercent(); RAM_usage_percent_.Set(usage_percent); } void PrometheusMetrics::GPUPercentGaugeSet() { - if (!startup_) return; + if (!startup_) { + return; + } + int numDevice = server::SystemInfo::GetInstance().num_device(); std::vector used_total = server::SystemInfo::GetInstance().GPUMemoryTotal(); std::vector used_memory = server::SystemInfo::GetInstance().GPUMemoryUsed(); for (int i = 0; i < numDevice; ++i) { - prometheus::Gauge &GPU_percent = GPU_percent_.Add({{"DeviceNum", std::to_string(i)}}); - double percent = (double) used_memory[i] / (double) used_total[i]; + prometheus::Gauge& GPU_percent = GPU_percent_.Add({{"DeviceNum", std::to_string(i)}}); + double percent = (double)used_memory[i] / (double)used_total[i]; GPU_percent.Set(percent * 100); } } void PrometheusMetrics::GPUMemoryUsageGaugeSet() { - if (!startup_) return; + if (!startup_) { + return; + } + std::vector values = server::SystemInfo::GetInstance().GPUMemoryUsed(); constexpr uint64_t MtoB = 1024 * 1024; int numDevice = server::SystemInfo::GetInstance().num_device(); for (int i = 0; i < numDevice; ++i) { - prometheus::Gauge &GPU_memory = GPU_memory_usage_.Add({{"DeviceNum", std::to_string(i)}}); + prometheus::Gauge& GPU_memory = GPU_memory_usage_.Add({{"DeviceNum", std::to_string(i)}}); GPU_memory.Set(values[i] / MtoB); } } @@ -101,7 +118,9 @@ PrometheusMetrics::GPUMemoryUsageGaugeSet() { void PrometheusMetrics::AddVectorsPerSecondGaugeSet(int num_vector, int dim, double time) { // MB/s - if (!startup_) return; + if (!startup_) { + return; + } int64_t MtoB = 1024 * 1024; int64_t size = num_vector * dim * 4; @@ -110,7 +129,10 @@ PrometheusMetrics::AddVectorsPerSecondGaugeSet(int num_vector, int dim, double t void PrometheusMetrics::QueryIndexTypePerSecondSet(std::string type, double value) { - if (!startup_) return; + if (!startup_) { + return; + } + if (type == "IVF") { query_index_IVF_type_per_second_gauge_.Set(value); } else if (type == "IDMap") { @@ -120,19 +142,27 @@ PrometheusMetrics::QueryIndexTypePerSecondSet(std::string type, double value) { void PrometheusMetrics::ConnectionGaugeIncrement() { - if (!startup_) return; + if (!startup_) { + return; + } + connection_gauge_.Increment(); } void PrometheusMetrics::ConnectionGaugeDecrement() { - if (!startup_) return; + if (!startup_) { + return; + } + connection_gauge_.Decrement(); } void PrometheusMetrics::OctetsSet() { - if (!startup_) return; + if (!startup_) { + return; + } // get old stats and reset them uint64_t old_inoctets = SystemInfo::GetInstance().get_inoctets(); @@ -148,61 +178,66 @@ PrometheusMetrics::OctetsSet() { auto now_time = std::chrono::system_clock::now(); auto total_microsecond = METRICS_MICROSECONDS(old_time, now_time); auto total_second = total_microsecond * micro_to_second; - if (total_second == 0) return; + if (total_second == 0) { + return; + } + inoctets_gauge_.Set((in_and_out_octets.first - old_inoctets) / total_second); outoctets_gauge_.Set((in_and_out_octets.second - old_outoctets) / total_second); } void PrometheusMetrics::CPUCoreUsagePercentSet() { - if (!startup_) + if (!startup_) { return; + } std::vector cpu_core_percent = server::SystemInfo::GetInstance().CPUCorePercent(); for (int i = 0; i < cpu_core_percent.size(); ++i) { - prometheus::Gauge &core_percent = CPU_.Add({{"CPU", std::to_string(i)}}); + prometheus::Gauge& core_percent = CPU_.Add({{"CPU", std::to_string(i)}}); core_percent.Set(cpu_core_percent[i]); } } void PrometheusMetrics::GPUTemperature() { - if (!startup_) + if (!startup_) { return; + } std::vector GPU_temperatures = server::SystemInfo::GetInstance().GPUTemperature(); for (int i = 0; i < GPU_temperatures.size(); ++i) { - prometheus::Gauge &gpu_temp = GPU_temperature_.Add({{"GPU", std::to_string(i)}}); + prometheus::Gauge& gpu_temp = GPU_temperature_.Add({{"GPU", std::to_string(i)}}); gpu_temp.Set(GPU_temperatures[i]); } } void PrometheusMetrics::CPUTemperature() { - if (!startup_) + if (!startup_) { return; + } std::vector CPU_temperatures = server::SystemInfo::GetInstance().CPUTemperature(); for (int i = 0; i < CPU_temperatures.size(); ++i) { - prometheus::Gauge &cpu_temp = CPU_temperature_.Add({{"CPU", std::to_string(i)}}); + prometheus::Gauge& cpu_temp = CPU_temperature_.Add({{"CPU", std::to_string(i)}}); cpu_temp.Set(CPU_temperatures[i]); } } void PrometheusMetrics::GpuCacheUsageGaugeSet() { -// std::vector gpu_ids = {0}; -// for(auto i = 0; i < gpu_ids.size(); ++i) { -// uint64_t cache_usage = cache::GpuCacheMgr::GetInstance(gpu_ids[i])->CacheUsage(); -// uint64_t cache_capacity = cache::GpuCacheMgr::GetInstance(gpu_ids[i])->CacheCapacity(); -// prometheus::Gauge &gpu_cache = gpu_cache_usage_.Add({{"GPU_Cache", std::to_string(i)}}); -// gpu_cache.Set(cache_usage * 100 / cache_capacity); -// } + // std::vector gpu_ids = {0}; + // for(auto i = 0; i < gpu_ids.size(); ++i) { + // uint64_t cache_usage = cache::GpuCacheMgr::GetInstance(gpu_ids[i])->CacheUsage(); + // uint64_t cache_capacity = cache::GpuCacheMgr::GetInstance(gpu_ids[i])->CacheCapacity(); + // prometheus::Gauge &gpu_cache = gpu_cache_usage_.Add({{"GPU_Cache", std::to_string(i)}}); + // gpu_cache.Set(cache_usage * 100 / cache_capacity); + // } } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/PrometheusMetrics.h b/cpp/src/metrics/PrometheusMetrics.h index ce45c5e71155442248be00798e298eb8cbdefa11..ef60f9a231f97b8dcbb97e984fbaf3920cb25b50 100644 --- a/cpp/src/metrics/PrometheusMetrics.h +++ b/cpp/src/metrics/PrometheusMetrics.h @@ -17,33 +17,33 @@ #pragma once -#include -#include -#include #include +#include #include +#include #include +#include -#include "utils/Error.h" #include "MetricBase.h" +#include "utils/Error.h" #define METRICS_NOW_TIME std::chrono::system_clock::now() //#define server::Metrics::GetInstance() server::GetInstance() -#define METRICS_MICROSECONDS(a, b) (std::chrono::duration_cast (b-a)).count(); +#define METRICS_MICROSECONDS(a, b) (std::chrono::duration_cast(b - a)).count(); -namespace zilliz { namespace milvus { namespace server { class PrometheusMetrics : public MetricsBase { public: - static PrometheusMetrics & + static PrometheusMetrics& GetInstance() { static PrometheusMetrics instance; return instance; } - ErrorCode Init(); + ErrorCode + Init(); private: std::shared_ptr exposer_ptr_; @@ -51,163 +51,191 @@ class PrometheusMetrics : public MetricsBase { bool startup_ = false; public: - void SetStartup(bool startup) { + void + SetStartup(bool startup) { startup_ = startup; } - void AddVectorsSuccessTotalIncrement(double value = 1.0) override { + void + AddVectorsSuccessTotalIncrement(double value = 1.0) override { if (startup_) { add_vectors_success_total_.Increment(value); } } - void AddVectorsFailTotalIncrement(double value = 1.0) override { + void + AddVectorsFailTotalIncrement(double value = 1.0) override { if (startup_) { add_vectors_fail_total_.Increment(value); } } - void AddVectorsDurationHistogramOberve(double value) override { + void + AddVectorsDurationHistogramOberve(double value) override { if (startup_) { add_vectors_duration_histogram_.Observe(value); } } - void RawFileSizeHistogramObserve(double value) override { + void + RawFileSizeHistogramObserve(double value) override { if (startup_) { raw_files_size_histogram_.Observe(value); } } - void IndexFileSizeHistogramObserve(double value) override { + void + IndexFileSizeHistogramObserve(double value) override { if (startup_) { index_files_size_histogram_.Observe(value); } } - void BuildIndexDurationSecondsHistogramObserve(double value) override { + void + BuildIndexDurationSecondsHistogramObserve(double value) override { if (startup_) { build_index_duration_seconds_histogram_.Observe(value); } } - void CpuCacheUsageGaugeSet(double value) override { + void + CpuCacheUsageGaugeSet(double value) override { if (startup_) { cpu_cache_usage_gauge_.Set(value); } } - void GpuCacheUsageGaugeSet() override; + void + GpuCacheUsageGaugeSet() override; - void MetaAccessTotalIncrement(double value = 1) override { + void + MetaAccessTotalIncrement(double value = 1) override { if (startup_) { meta_access_total_.Increment(value); } } - void MetaAccessDurationSecondsHistogramObserve(double value) override { + void + MetaAccessDurationSecondsHistogramObserve(double value) override { if (startup_) { meta_access_duration_seconds_histogram_.Observe(value); } } - void FaissDiskLoadDurationSecondsHistogramObserve(double value) override { + void + FaissDiskLoadDurationSecondsHistogramObserve(double value) override { if (startup_) { faiss_disk_load_duration_seconds_histogram_.Observe(value); } } - void FaissDiskLoadSizeBytesHistogramObserve(double value) override { + void + FaissDiskLoadSizeBytesHistogramObserve(double value) override { if (startup_) { faiss_disk_load_size_bytes_histogram_.Observe(value); } } - void FaissDiskLoadIOSpeedGaugeSet(double value) override { + void + FaissDiskLoadIOSpeedGaugeSet(double value) override { if (startup_) { faiss_disk_load_IO_speed_gauge_.Set(value); } } - void CacheAccessTotalIncrement(double value = 1) override { + void + CacheAccessTotalIncrement(double value = 1) override { if (startup_) { cache_access_total_.Increment(value); } } - void MemTableMergeDurationSecondsHistogramObserve(double value) override { + void + MemTableMergeDurationSecondsHistogramObserve(double value) override { if (startup_) { mem_table_merge_duration_seconds_histogram_.Observe(value); } } - void SearchIndexDataDurationSecondsHistogramObserve(double value) override { + void + SearchIndexDataDurationSecondsHistogramObserve(double value) override { if (startup_) { search_index_data_duration_seconds_histogram_.Observe(value); } } - void SearchRawDataDurationSecondsHistogramObserve(double value) override { + void + SearchRawDataDurationSecondsHistogramObserve(double value) override { if (startup_) { search_raw_data_duration_seconds_histogram_.Observe(value); } } - void IndexFileSizeTotalIncrement(double value = 1) override { + void + IndexFileSizeTotalIncrement(double value = 1) override { if (startup_) { index_file_size_total_.Increment(value); } } - void RawFileSizeTotalIncrement(double value = 1) override { + void + RawFileSizeTotalIncrement(double value = 1) override { if (startup_) { raw_file_size_total_.Increment(value); } } - void IndexFileSizeGaugeSet(double value) override { + void + IndexFileSizeGaugeSet(double value) override { if (startup_) { index_file_size_gauge_.Set(value); } } - void RawFileSizeGaugeSet(double value) override { + void + RawFileSizeGaugeSet(double value) override { if (startup_) { raw_file_size_gauge_.Set(value); } } - void QueryResponseSummaryObserve(double value) override { + void + QueryResponseSummaryObserve(double value) override { if (startup_) { query_response_summary_.Observe(value); } } - void DiskStoreIOSpeedGaugeSet(double value) override { + void + DiskStoreIOSpeedGaugeSet(double value) override { if (startup_) { disk_store_IO_speed_gauge_.Set(value); } } - void DataFileSizeGaugeSet(double value) override { + void + DataFileSizeGaugeSet(double value) override { if (startup_) { data_file_size_gauge_.Set(value); } } - void AddVectorsSuccessGaugeSet(double value) override { + void + AddVectorsSuccessGaugeSet(double value) override { if (startup_) { add_vectors_success_gauge_.Set(value); } } - void AddVectorsFailGaugeSet(double value) override { + void + AddVectorsFailGaugeSet(double value) override { if (startup_) { add_vectors_fail_gauge_.Set(value); } } - void QueryVectorResponseSummaryObserve(double value, int count = 1) override { + void + QueryVectorResponseSummaryObserve(double value, int count = 1) override { if (startup_) { for (int i = 0; i < count; ++i) { query_vector_response_summary_.Observe(value); @@ -215,412 +243,413 @@ class PrometheusMetrics : public MetricsBase { } } - void QueryVectorResponsePerSecondGaugeSet(double value) override { + void + QueryVectorResponsePerSecondGaugeSet(double value) override { if (startup_) { query_vector_response_per_second_gauge_.Set(value); } } - void CPUUsagePercentSet() override; - void CPUCoreUsagePercentSet() override; + void + CPUUsagePercentSet() override; + void + CPUCoreUsagePercentSet() override; - void RAMUsagePercentSet() override; + void + RAMUsagePercentSet() override; - void QueryResponsePerSecondGaugeSet(double value) override { + void + QueryResponsePerSecondGaugeSet(double value) override { if (startup_) { query_response_per_second_gauge.Set(value); } } - void GPUPercentGaugeSet() override; - void GPUMemoryUsageGaugeSet() override; - void AddVectorsPerSecondGaugeSet(int num_vector, int dim, double time) override; - void QueryIndexTypePerSecondSet(std::string type, double value) override; - void ConnectionGaugeIncrement() override; - void ConnectionGaugeDecrement() override; - - void KeepingAliveCounterIncrement(double value = 1) override { + void + GPUPercentGaugeSet() override; + void + GPUMemoryUsageGaugeSet() override; + void + AddVectorsPerSecondGaugeSet(int num_vector, int dim, double time) override; + void + QueryIndexTypePerSecondSet(std::string type, double value) override; + void + ConnectionGaugeIncrement() override; + void + ConnectionGaugeDecrement() override; + + void + KeepingAliveCounterIncrement(double value = 1) override { if (startup_) { keeping_alive_counter_.Increment(value); } } - void OctetsSet() override; + void + OctetsSet() override; - void GPUTemperature() override; - void CPUTemperature() override; + void + GPUTemperature() override; + void + CPUTemperature() override; - std::shared_ptr &exposer_ptr() { + std::shared_ptr& + exposer_ptr() { return exposer_ptr_; } -// prometheus::Exposer& exposer() { return exposer_;} - std::shared_ptr ®istry_ptr() { + // prometheus::Exposer& exposer() { return exposer_;} + std::shared_ptr& + registry_ptr() { return registry_; } // ..... private: ////all from db_connection.cpp -// prometheus::Family &connect_request_ = prometheus::BuildCounter() -// .Name("connection_total") -// .Help("total number of connection has been made") -// .Register(*registry_); -// prometheus::Counter &connection_total_ = connect_request_.Add({}); + // prometheus::Family &connect_request_ = prometheus::BuildCounter() + // .Name("connection_total") + // .Help("total number of connection has been made") + // .Register(*registry_); + // prometheus::Counter &connection_total_ = connect_request_.Add({}); ////all from DBImpl.cpp using BucketBoundaries = std::vector; - //record add_group request - prometheus::Family &add_group_request_ = prometheus::BuildCounter() - .Name("add_group_request_total") - .Help("the number of add_group request") - .Register(*registry_); - - prometheus::Counter &add_group_success_total_ = add_group_request_.Add({{"outcome", "success"}}); - prometheus::Counter &add_group_fail_total_ = add_group_request_.Add({{"outcome", "fail"}}); - - //record get_group request - prometheus::Family &get_group_request_ = prometheus::BuildCounter() - .Name("get_group_request_total") - .Help("the number of get_group request") - .Register(*registry_); - - prometheus::Counter &get_group_success_total_ = get_group_request_.Add({{"outcome", "success"}}); - prometheus::Counter &get_group_fail_total_ = get_group_request_.Add({{"outcome", "fail"}}); - - //record has_group request - prometheus::Family &has_group_request_ = prometheus::BuildCounter() - .Name("has_group_request_total") - .Help("the number of has_group request") - .Register(*registry_); - - prometheus::Counter &has_group_success_total_ = has_group_request_.Add({{"outcome", "success"}}); - prometheus::Counter &has_group_fail_total_ = has_group_request_.Add({{"outcome", "fail"}}); - - //record get_group_files - prometheus::Family &get_group_files_request_ = prometheus::BuildCounter() - .Name("get_group_files_request_total") - .Help("the number of get_group_files request") - .Register(*registry_); - - prometheus::Counter &get_group_files_success_total_ = get_group_files_request_.Add({{"outcome", "success"}}); - prometheus::Counter &get_group_files_fail_total_ = get_group_files_request_.Add({{"outcome", "fail"}}); - - //record add_vectors count and average time - //need to be considered - prometheus::Family &add_vectors_request_ = prometheus::BuildCounter() - .Name("add_vectors_request_total") - .Help("the number of vectors added") - .Register(*registry_); - prometheus::Counter &add_vectors_success_total_ = add_vectors_request_.Add({{"outcome", "success"}}); - prometheus::Counter &add_vectors_fail_total_ = add_vectors_request_.Add({{"outcome", "fail"}}); - - prometheus::Family &add_vectors_duration_seconds_ = prometheus::BuildHistogram() - .Name("add_vector_duration_microseconds") - .Help("average time of adding every vector") - .Register(*registry_); - prometheus::Histogram &add_vectors_duration_histogram_ = + // record add_group request + prometheus::Family& add_group_request_ = prometheus::BuildCounter() + .Name("add_group_request_total") + .Help("the number of add_group request") + .Register(*registry_); + + prometheus::Counter& add_group_success_total_ = add_group_request_.Add({{"outcome", "success"}}); + prometheus::Counter& add_group_fail_total_ = add_group_request_.Add({{"outcome", "fail"}}); + + // record get_group request + prometheus::Family& get_group_request_ = prometheus::BuildCounter() + .Name("get_group_request_total") + .Help("the number of get_group request") + .Register(*registry_); + + prometheus::Counter& get_group_success_total_ = get_group_request_.Add({{"outcome", "success"}}); + prometheus::Counter& get_group_fail_total_ = get_group_request_.Add({{"outcome", "fail"}}); + + // record has_group request + prometheus::Family& has_group_request_ = prometheus::BuildCounter() + .Name("has_group_request_total") + .Help("the number of has_group request") + .Register(*registry_); + + prometheus::Counter& has_group_success_total_ = has_group_request_.Add({{"outcome", "success"}}); + prometheus::Counter& has_group_fail_total_ = has_group_request_.Add({{"outcome", "fail"}}); + + // record get_group_files + prometheus::Family& get_group_files_request_ = + prometheus::BuildCounter() + .Name("get_group_files_request_total") + .Help("the number of get_group_files request") + .Register(*registry_); + + prometheus::Counter& get_group_files_success_total_ = get_group_files_request_.Add({{"outcome", "success"}}); + prometheus::Counter& get_group_files_fail_total_ = get_group_files_request_.Add({{"outcome", "fail"}}); + + // record add_vectors count and average time + // need to be considered + prometheus::Family& add_vectors_request_ = prometheus::BuildCounter() + .Name("add_vectors_request_total") + .Help("the number of vectors added") + .Register(*registry_); + prometheus::Counter& add_vectors_success_total_ = add_vectors_request_.Add({{"outcome", "success"}}); + prometheus::Counter& add_vectors_fail_total_ = add_vectors_request_.Add({{"outcome", "fail"}}); + + prometheus::Family& add_vectors_duration_seconds_ = + prometheus::BuildHistogram() + .Name("add_vector_duration_microseconds") + .Help("average time of adding every vector") + .Register(*registry_); + prometheus::Histogram& add_vectors_duration_histogram_ = add_vectors_duration_seconds_.Add({}, BucketBoundaries{0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.5, 1}); - //record search count and average time - prometheus::Family &search_request_ = prometheus::BuildCounter() - .Name("search_request_total") - .Help("the number of search request") - .Register(*registry_); - prometheus::Counter &search_success_total_ = search_request_.Add({{"outcome", "success"}}); - prometheus::Counter &search_fail_total_ = search_request_.Add({{"outcome", "fail"}}); - - prometheus::Family &search_request_duration_seconds_ = prometheus::BuildHistogram() - .Name("search_request_duration_microsecond") - .Help("histogram of processing time for each search") - .Register(*registry_); - prometheus::Histogram - &search_duration_histogram_ = search_request_duration_seconds_.Add({}, BucketBoundaries{0.1, 1.0, 10.0}); - - //record raw_files size histogram - prometheus::Family &raw_files_size_ = prometheus::BuildHistogram() - .Name("search_raw_files_bytes") - .Help("histogram of raw files size by bytes") - .Register(*registry_); - prometheus::Histogram - &raw_files_size_histogram_ = raw_files_size_.Add({}, BucketBoundaries{1e9, 2e9, 4e9, 6e9, 8e9, 1e10}); - - //record index_files size histogram - prometheus::Family &index_files_size_ = prometheus::BuildHistogram() - .Name("search_index_files_bytes") - .Help("histogram of index files size by bytes") - .Register(*registry_); - prometheus::Histogram - &index_files_size_histogram_ = index_files_size_.Add({}, BucketBoundaries{1e9, 2e9, 4e9, 6e9, 8e9, 1e10}); - - //record index and raw files size counter - prometheus::Family &file_size_total_ = prometheus::BuildCounter() - .Name("search_file_size_total") - .Help("searched index and raw file size") - .Register(*registry_); - prometheus::Counter &index_file_size_total_ = file_size_total_.Add({{"type", "index"}}); - prometheus::Counter &raw_file_size_total_ = file_size_total_.Add({{"type", "raw"}}); - - //record index and raw files size counter - prometheus::Family &file_size_gauge_ = prometheus::BuildGauge() - .Name("search_file_size_gauge") - .Help("searched current index and raw file size") - .Register(*registry_); - prometheus::Gauge &index_file_size_gauge_ = file_size_gauge_.Add({{"type", "index"}}); - prometheus::Gauge &raw_file_size_gauge_ = file_size_gauge_.Add({{"type", "raw"}}); - - //record processing time for building index - prometheus::Family &build_index_duration_seconds_ = prometheus::BuildHistogram() - .Name("build_index_duration_microseconds") - .Help("histogram of processing time for building index") - .Register(*registry_); - prometheus::Histogram &build_index_duration_seconds_histogram_ = + // record search count and average time + prometheus::Family& search_request_ = prometheus::BuildCounter() + .Name("search_request_total") + .Help("the number of search request") + .Register(*registry_); + prometheus::Counter& search_success_total_ = search_request_.Add({{"outcome", "success"}}); + prometheus::Counter& search_fail_total_ = search_request_.Add({{"outcome", "fail"}}); + + prometheus::Family& search_request_duration_seconds_ = + prometheus::BuildHistogram() + .Name("search_request_duration_microsecond") + .Help("histogram of processing time for each search") + .Register(*registry_); + prometheus::Histogram& search_duration_histogram_ = + search_request_duration_seconds_.Add({}, BucketBoundaries{0.1, 1.0, 10.0}); + + // record raw_files size histogram + prometheus::Family& raw_files_size_ = prometheus::BuildHistogram() + .Name("search_raw_files_bytes") + .Help("histogram of raw files size by bytes") + .Register(*registry_); + prometheus::Histogram& raw_files_size_histogram_ = + raw_files_size_.Add({}, BucketBoundaries{1e9, 2e9, 4e9, 6e9, 8e9, 1e10}); + + // record index_files size histogram + prometheus::Family& index_files_size_ = prometheus::BuildHistogram() + .Name("search_index_files_bytes") + .Help("histogram of index files size by bytes") + .Register(*registry_); + prometheus::Histogram& index_files_size_histogram_ = + index_files_size_.Add({}, BucketBoundaries{1e9, 2e9, 4e9, 6e9, 8e9, 1e10}); + + // record index and raw files size counter + prometheus::Family& file_size_total_ = prometheus::BuildCounter() + .Name("search_file_size_total") + .Help("searched index and raw file size") + .Register(*registry_); + prometheus::Counter& index_file_size_total_ = file_size_total_.Add({{"type", "index"}}); + prometheus::Counter& raw_file_size_total_ = file_size_total_.Add({{"type", "raw"}}); + + // record index and raw files size counter + prometheus::Family& file_size_gauge_ = prometheus::BuildGauge() + .Name("search_file_size_gauge") + .Help("searched current index and raw file size") + .Register(*registry_); + prometheus::Gauge& index_file_size_gauge_ = file_size_gauge_.Add({{"type", "index"}}); + prometheus::Gauge& raw_file_size_gauge_ = file_size_gauge_.Add({{"type", "raw"}}); + + // record processing time for building index + prometheus::Family& build_index_duration_seconds_ = + prometheus::BuildHistogram() + .Name("build_index_duration_microseconds") + .Help("histogram of processing time for building index") + .Register(*registry_); + prometheus::Histogram& build_index_duration_seconds_histogram_ = build_index_duration_seconds_.Add({}, BucketBoundaries{5e5, 2e6, 4e6, 6e6, 8e6, 1e7}); - //record processing time for all building index - prometheus::Family &all_build_index_duration_seconds_ = prometheus::BuildHistogram() - .Name("all_build_index_duration_microseconds") - .Help("histogram of processing time for building index") - .Register(*registry_); - prometheus::Histogram &all_build_index_duration_seconds_histogram_ = + // record processing time for all building index + prometheus::Family& all_build_index_duration_seconds_ = + prometheus::BuildHistogram() + .Name("all_build_index_duration_microseconds") + .Help("histogram of processing time for building index") + .Register(*registry_); + prometheus::Histogram& all_build_index_duration_seconds_histogram_ = all_build_index_duration_seconds_.Add({}, BucketBoundaries{2e6, 4e6, 6e6, 8e6, 1e7}); - //record duration of merging mem table - prometheus::Family &mem_table_merge_duration_seconds_ = prometheus::BuildHistogram() - .Name("mem_table_merge_duration_microseconds") - .Help("histogram of processing time for merging mem tables") - .Register(*registry_); - prometheus::Histogram &mem_table_merge_duration_seconds_histogram_ = + // record duration of merging mem table + prometheus::Family& mem_table_merge_duration_seconds_ = + prometheus::BuildHistogram() + .Name("mem_table_merge_duration_microseconds") + .Help("histogram of processing time for merging mem tables") + .Register(*registry_); + prometheus::Histogram& mem_table_merge_duration_seconds_histogram_ = mem_table_merge_duration_seconds_.Add({}, BucketBoundaries{5e4, 1e5, 2e5, 4e5, 6e5, 8e5, 1e6}); - //record search index and raw data duration - prometheus::Family &search_data_duration_seconds_ = prometheus::BuildHistogram() - .Name("search_data_duration_microseconds") - .Help("histograms of processing time for search index and raw data") - .Register(*registry_); - prometheus::Histogram &search_index_data_duration_seconds_histogram_ = + // record search index and raw data duration + prometheus::Family& search_data_duration_seconds_ = + prometheus::BuildHistogram() + .Name("search_data_duration_microseconds") + .Help("histograms of processing time for search index and raw data") + .Register(*registry_); + prometheus::Histogram& search_index_data_duration_seconds_histogram_ = search_data_duration_seconds_.Add({{"type", "index"}}, BucketBoundaries{1e5, 2e5, 4e5, 6e5, 8e5}); - prometheus::Histogram &search_raw_data_duration_seconds_histogram_ = + prometheus::Histogram& search_raw_data_duration_seconds_histogram_ = search_data_duration_seconds_.Add({{"type", "raw"}}, BucketBoundaries{1e5, 2e5, 4e5, 6e5, 8e5}); - ////all form Cache.cpp - //record cache usage, when insert/erase/clear/free - + // record cache usage, when insert/erase/clear/free ////all from Meta.cpp - //record meta visit count and time -// prometheus::Family &meta_visit_ = prometheus::BuildCounter() -// .Name("meta_visit_total") -// .Help("the number of accessing Meta") -// .Register(*registry_); -// prometheus::Counter &meta_visit_total_ = meta_visit_.Add({{}}); -// -// prometheus::Family &meta_visit_duration_seconds_ = prometheus::BuildHistogram() -// .Name("meta_visit_duration_seconds") -// .Help("histogram of processing time to get data from mata") -// .Register(*registry_); -// prometheus::Histogram &meta_visit_duration_seconds_histogram_ = -// meta_visit_duration_seconds_.Add({{}}, BucketBoundaries{0.1, 1.0, 10.0}); - + // record meta visit count and time + // prometheus::Family &meta_visit_ = prometheus::BuildCounter() + // .Name("meta_visit_total") + // .Help("the number of accessing Meta") + // .Register(*registry_); + // prometheus::Counter &meta_visit_total_ = meta_visit_.Add({{}}); + // + // prometheus::Family &meta_visit_duration_seconds_ = prometheus::BuildHistogram() + // .Name("meta_visit_duration_seconds") + // .Help("histogram of processing time to get data from mata") + // .Register(*registry_); + // prometheus::Histogram &meta_visit_duration_seconds_histogram_ = + // meta_visit_duration_seconds_.Add({{}}, BucketBoundaries{0.1, 1.0, 10.0}); ////all from MemManager.cpp - //record memory usage percent - prometheus::Family &mem_usage_percent_ = prometheus::BuildGauge() - .Name("memory_usage_percent") - .Help("memory usage percent") - .Register(*registry_); - prometheus::Gauge &mem_usage_percent_gauge_ = mem_usage_percent_.Add({}); - - //record memory usage toal - prometheus::Family &mem_usage_total_ = prometheus::BuildGauge() - .Name("memory_usage_total") - .Help("memory usage total") - .Register(*registry_); - prometheus::Gauge &mem_usage_total_gauge_ = mem_usage_total_.Add({}); + // record memory usage percent + prometheus::Family& mem_usage_percent_ = + prometheus::BuildGauge().Name("memory_usage_percent").Help("memory usage percent").Register(*registry_); + prometheus::Gauge& mem_usage_percent_gauge_ = mem_usage_percent_.Add({}); + + // record memory usage toal + prometheus::Family& mem_usage_total_ = + prometheus::BuildGauge().Name("memory_usage_total").Help("memory usage total").Register(*registry_); + prometheus::Gauge& mem_usage_total_gauge_ = mem_usage_total_.Add({}); ////all from DBMetaImpl.cpp - //record meta access count - prometheus::Family &meta_access_ = prometheus::BuildCounter() - .Name("meta_access_total") - .Help("the number of meta accessing") - .Register(*registry_); - prometheus::Counter &meta_access_total_ = meta_access_.Add({}); - - //record meta access duration - prometheus::Family &meta_access_duration_seconds_ = prometheus::BuildHistogram() - .Name("meta_access_duration_microseconds") - .Help("histogram of processing time for accessing mata") - .Register(*registry_); - prometheus::Histogram &meta_access_duration_seconds_histogram_ = + // record meta access count + prometheus::Family& meta_access_ = + prometheus::BuildCounter().Name("meta_access_total").Help("the number of meta accessing").Register(*registry_); + prometheus::Counter& meta_access_total_ = meta_access_.Add({}); + + // record meta access duration + prometheus::Family& meta_access_duration_seconds_ = + prometheus::BuildHistogram() + .Name("meta_access_duration_microseconds") + .Help("histogram of processing time for accessing mata") + .Register(*registry_); + prometheus::Histogram& meta_access_duration_seconds_histogram_ = meta_access_duration_seconds_.Add({}, BucketBoundaries{100, 300, 500, 700, 900, 2000, 4000, 6000, 8000, 20000}); ////all from FaissExecutionEngine.cpp - //record data loading from disk count, size, duration, IO speed - prometheus::Family &disk_load_duration_second_ = prometheus::BuildHistogram() - .Name("disk_load_duration_microseconds") - .Help("Histogram of processing time for loading data from disk") - .Register(*registry_); - prometheus::Histogram &faiss_disk_load_duration_seconds_histogram_ = + // record data loading from disk count, size, duration, IO speed + prometheus::Family& disk_load_duration_second_ = + prometheus::BuildHistogram() + .Name("disk_load_duration_microseconds") + .Help("Histogram of processing time for loading data from disk") + .Register(*registry_); + prometheus::Histogram& faiss_disk_load_duration_seconds_histogram_ = disk_load_duration_second_.Add({{"DB", "Faiss"}}, BucketBoundaries{2e5, 4e5, 6e5, 8e5}); - prometheus::Family &disk_load_size_bytes_ = prometheus::BuildHistogram() - .Name("disk_load_size_bytes") - .Help("Histogram of data size by bytes for loading data from disk") - .Register(*registry_); - prometheus::Histogram &faiss_disk_load_size_bytes_histogram_ = + prometheus::Family& disk_load_size_bytes_ = + prometheus::BuildHistogram() + .Name("disk_load_size_bytes") + .Help("Histogram of data size by bytes for loading data from disk") + .Register(*registry_); + prometheus::Histogram& faiss_disk_load_size_bytes_histogram_ = disk_load_size_bytes_.Add({{"DB", "Faiss"}}, BucketBoundaries{1e9, 2e9, 4e9, 6e9, 8e9}); -// prometheus::Family &disk_load_IO_speed_ = prometheus::BuildHistogram() -// .Name("disk_load_IO_speed_byte_per_sec") -// .Help("Histogram of IO speed for loading data from disk") -// .Register(*registry_); -// prometheus::Histogram &faiss_disk_load_IO_speed_histogram_ = -// disk_load_IO_speed_.Add({{"DB","Faiss"}},BucketBoundaries{1000, 2000, 3000, 4000, 6000, 8000}); + // prometheus::Family &disk_load_IO_speed_ = prometheus::BuildHistogram() + // .Name("disk_load_IO_speed_byte_per_sec") + // .Help("Histogram of IO speed for loading data from disk") + // .Register(*registry_); + // prometheus::Histogram &faiss_disk_load_IO_speed_histogram_ = + // disk_load_IO_speed_.Add({{"DB","Faiss"}},BucketBoundaries{1000, 2000, 3000, 4000, 6000, 8000}); - prometheus::Family &faiss_disk_load_IO_speed_ = prometheus::BuildGauge() - .Name("disk_load_IO_speed_byte_per_microsec") - .Help("disk IO speed ") - .Register(*registry_); - prometheus::Gauge &faiss_disk_load_IO_speed_gauge_ = faiss_disk_load_IO_speed_.Add({{"DB", "Faiss"}}); + prometheus::Family& faiss_disk_load_IO_speed_ = prometheus::BuildGauge() + .Name("disk_load_IO_speed_byte_per_microsec") + .Help("disk IO speed ") + .Register(*registry_); + prometheus::Gauge& faiss_disk_load_IO_speed_gauge_ = faiss_disk_load_IO_speed_.Add({{"DB", "Faiss"}}); ////all from CacheMgr.cpp - //record cache access count - prometheus::Family &cache_access_ = prometheus::BuildCounter() - .Name("cache_access_total") - .Help("the count of accessing cache ") - .Register(*registry_); - prometheus::Counter &cache_access_total_ = cache_access_.Add({}); + // record cache access count + prometheus::Family& cache_access_ = prometheus::BuildCounter() + .Name("cache_access_total") + .Help("the count of accessing cache ") + .Register(*registry_); + prometheus::Counter& cache_access_total_ = cache_access_.Add({}); // record CPU cache usage and % - prometheus::Family &cpu_cache_usage_ = prometheus::BuildGauge() - .Name("cache_usage_bytes") - .Help("current cache usage by bytes") - .Register(*registry_); - prometheus::Gauge &cpu_cache_usage_gauge_ = cpu_cache_usage_.Add({}); - - //record GPU cache usage and % - prometheus::Family &gpu_cache_usage_ = prometheus::BuildGauge() - .Name("gpu_cache_usage_bytes") - .Help("current gpu cache usage by bytes") - .Register(*registry_); + prometheus::Family& cpu_cache_usage_ = + prometheus::BuildGauge().Name("cache_usage_bytes").Help("current cache usage by bytes").Register(*registry_); + prometheus::Gauge& cpu_cache_usage_gauge_ = cpu_cache_usage_.Add({}); + + // record GPU cache usage and % + prometheus::Family& gpu_cache_usage_ = prometheus::BuildGauge() + .Name("gpu_cache_usage_bytes") + .Help("current gpu cache usage by bytes") + .Register(*registry_); // record query response using Quantiles = std::vector; - prometheus::Family &query_response_ = prometheus::BuildSummary() - .Name("query_response_summary") - .Help("query response summary") - .Register(*registry_); - prometheus::Summary - &query_response_summary_ = query_response_.Add({}, Quantiles{{0.95, 0.00}, {0.9, 0.05}, {0.8, 0.1}}); - - prometheus::Family &query_vector_response_ = prometheus::BuildSummary() - .Name("query_vector_response_summary") - .Help("query each vector response summary") - .Register(*registry_); - prometheus::Summary &query_vector_response_summary_ = + prometheus::Family& query_response_ = + prometheus::BuildSummary().Name("query_response_summary").Help("query response summary").Register(*registry_); + prometheus::Summary& query_response_summary_ = + query_response_.Add({}, Quantiles{{0.95, 0.00}, {0.9, 0.05}, {0.8, 0.1}}); + + prometheus::Family& query_vector_response_ = prometheus::BuildSummary() + .Name("query_vector_response_summary") + .Help("query each vector response summary") + .Register(*registry_); + prometheus::Summary& query_vector_response_summary_ = query_vector_response_.Add({}, Quantiles{{0.95, 0.00}, {0.9, 0.05}, {0.8, 0.1}}); - prometheus::Family &query_vector_response_per_second_ = prometheus::BuildGauge() - .Name("query_vector_response_per_microsecond") - .Help("the number of vectors can be queried every second ") - .Register(*registry_); - prometheus::Gauge &query_vector_response_per_second_gauge_ = query_vector_response_per_second_.Add({}); - - prometheus::Family &query_response_per_second_ = prometheus::BuildGauge() - .Name("query_response_per_microsecond") - .Help("the number of queries can be processed every microsecond") - .Register(*registry_); - prometheus::Gauge &query_response_per_second_gauge = query_response_per_second_.Add({}); - - prometheus::Family &disk_store_IO_speed_ = prometheus::BuildGauge() - .Name("disk_store_IO_speed_bytes_per_microseconds") - .Help("disk_store_IO_speed") - .Register(*registry_); - prometheus::Gauge &disk_store_IO_speed_gauge_ = disk_store_IO_speed_.Add({}); - - prometheus::Family &data_file_size_ = prometheus::BuildGauge() - .Name("data_file_size_bytes") - .Help("data file size by bytes") - .Register(*registry_); - prometheus::Gauge &data_file_size_gauge_ = data_file_size_.Add({}); - - prometheus::Family &add_vectors_ = prometheus::BuildGauge() - .Name("add_vectors") - .Help("current added vectors") - .Register(*registry_); - prometheus::Gauge &add_vectors_success_gauge_ = add_vectors_.Add({{"outcome", "success"}}); - prometheus::Gauge &add_vectors_fail_gauge_ = add_vectors_.Add({{"outcome", "fail"}}); - - prometheus::Family &add_vectors_per_second_ = prometheus::BuildGauge() - .Name("add_vectors_throughput_per_microsecond") - .Help("add vectors throughput per microsecond") - .Register(*registry_); - prometheus::Gauge &add_vectors_per_second_gauge_ = add_vectors_per_second_.Add({}); - - prometheus::Family &CPU_ = prometheus::BuildGauge() - .Name("CPU_usage_percent") - .Help("CPU usage percent by this this process") - .Register(*registry_); - prometheus::Gauge &CPU_usage_percent_ = CPU_.Add({{"CPU", "avg"}}); - - prometheus::Family &RAM_ = prometheus::BuildGauge() - .Name("RAM_usage_percent") - .Help("RAM usage percent by this process") - .Register(*registry_); - prometheus::Gauge &RAM_usage_percent_ = RAM_.Add({}); - - //GPU Usage Percent - prometheus::Family &GPU_percent_ = prometheus::BuildGauge() - .Name("Gpu_usage_percent") - .Help("GPU_usage_percent ") - .Register(*registry_); - - //GPU Mempry used - prometheus::Family &GPU_memory_usage_ = prometheus::BuildGauge() - .Name("GPU_memory_usage_total") - .Help("GPU memory usage total ") - .Register(*registry_); - - prometheus::Family &query_index_type_per_second_ = prometheus::BuildGauge() - .Name("query_index_throughtout_per_microsecond") - .Help("query index throughtout per microsecond") - .Register(*registry_); - prometheus::Gauge - &query_index_IVF_type_per_second_gauge_ = query_index_type_per_second_.Add({{"IndexType", "IVF"}}); - prometheus::Gauge - &query_index_IDMAP_type_per_second_gauge_ = query_index_type_per_second_.Add({{"IndexType", "IDMAP"}}); - - prometheus::Family &connection_ = prometheus::BuildGauge() - .Name("connection_number") - .Help("the number of connections") - .Register(*registry_); - prometheus::Gauge &connection_gauge_ = connection_.Add({}); - - prometheus::Family &keeping_alive_ = prometheus::BuildCounter() - .Name("keeping_alive_seconds_total") - .Help("total seconds of the serve alive") - .Register(*registry_); - prometheus::Counter &keeping_alive_counter_ = keeping_alive_.Add({}); - - prometheus::Family &octets_ = prometheus::BuildGauge() - .Name("octets_bytes_per_second") - .Help("octets bytes per second") - .Register(*registry_); - prometheus::Gauge &inoctets_gauge_ = octets_.Add({{"type", "inoctets"}}); - prometheus::Gauge &outoctets_gauge_ = octets_.Add({{"type", "outoctets"}}); - - prometheus::Family &GPU_temperature_ = prometheus::BuildGauge() - .Name("GPU_temperature") - .Help("GPU temperature") - .Register(*registry_); - - prometheus::Family &CPU_temperature_ = prometheus::BuildGauge() - .Name("CPU_temperature") - .Help("CPU temperature") - .Register(*registry_); + prometheus::Family& query_vector_response_per_second_ = + prometheus::BuildGauge() + .Name("query_vector_response_per_microsecond") + .Help("the number of vectors can be queried every second ") + .Register(*registry_); + prometheus::Gauge& query_vector_response_per_second_gauge_ = query_vector_response_per_second_.Add({}); + + prometheus::Family& query_response_per_second_ = + prometheus::BuildGauge() + .Name("query_response_per_microsecond") + .Help("the number of queries can be processed every microsecond") + .Register(*registry_); + prometheus::Gauge& query_response_per_second_gauge = query_response_per_second_.Add({}); + + prometheus::Family& disk_store_IO_speed_ = + prometheus::BuildGauge() + .Name("disk_store_IO_speed_bytes_per_microseconds") + .Help("disk_store_IO_speed") + .Register(*registry_); + prometheus::Gauge& disk_store_IO_speed_gauge_ = disk_store_IO_speed_.Add({}); + + prometheus::Family& data_file_size_ = + prometheus::BuildGauge().Name("data_file_size_bytes").Help("data file size by bytes").Register(*registry_); + prometheus::Gauge& data_file_size_gauge_ = data_file_size_.Add({}); + + prometheus::Family& add_vectors_ = + prometheus::BuildGauge().Name("add_vectors").Help("current added vectors").Register(*registry_); + prometheus::Gauge& add_vectors_success_gauge_ = add_vectors_.Add({{"outcome", "success"}}); + prometheus::Gauge& add_vectors_fail_gauge_ = add_vectors_.Add({{"outcome", "fail"}}); + + prometheus::Family& add_vectors_per_second_ = prometheus::BuildGauge() + .Name("add_vectors_throughput_per_microsecond") + .Help("add vectors throughput per microsecond") + .Register(*registry_); + prometheus::Gauge& add_vectors_per_second_gauge_ = add_vectors_per_second_.Add({}); + + prometheus::Family& CPU_ = prometheus::BuildGauge() + .Name("CPU_usage_percent") + .Help("CPU usage percent by this this process") + .Register(*registry_); + prometheus::Gauge& CPU_usage_percent_ = CPU_.Add({{"CPU", "avg"}}); + + prometheus::Family& RAM_ = prometheus::BuildGauge() + .Name("RAM_usage_percent") + .Help("RAM usage percent by this process") + .Register(*registry_); + prometheus::Gauge& RAM_usage_percent_ = RAM_.Add({}); + + // GPU Usage Percent + prometheus::Family& GPU_percent_ = + prometheus::BuildGauge().Name("Gpu_usage_percent").Help("GPU_usage_percent ").Register(*registry_); + + // GPU Mempry used + prometheus::Family& GPU_memory_usage_ = + prometheus::BuildGauge().Name("GPU_memory_usage_total").Help("GPU memory usage total ").Register(*registry_); + + prometheus::Family& query_index_type_per_second_ = + prometheus::BuildGauge() + .Name("query_index_throughtout_per_microsecond") + .Help("query index throughtout per microsecond") + .Register(*registry_); + prometheus::Gauge& query_index_IVF_type_per_second_gauge_ = + query_index_type_per_second_.Add({{"IndexType", "IVF"}}); + prometheus::Gauge& query_index_IDMAP_type_per_second_gauge_ = + query_index_type_per_second_.Add({{"IndexType", "IDMAP"}}); + + prometheus::Family& connection_ = + prometheus::BuildGauge().Name("connection_number").Help("the number of connections").Register(*registry_); + prometheus::Gauge& connection_gauge_ = connection_.Add({}); + + prometheus::Family& keeping_alive_ = prometheus::BuildCounter() + .Name("keeping_alive_seconds_total") + .Help("total seconds of the serve alive") + .Register(*registry_); + prometheus::Counter& keeping_alive_counter_ = keeping_alive_.Add({}); + + prometheus::Family& octets_ = + prometheus::BuildGauge().Name("octets_bytes_per_second").Help("octets bytes per second").Register(*registry_); + prometheus::Gauge& inoctets_gauge_ = octets_.Add({{"type", "inoctets"}}); + prometheus::Gauge& outoctets_gauge_ = octets_.Add({{"type", "outoctets"}}); + + prometheus::Family& GPU_temperature_ = + prometheus::BuildGauge().Name("GPU_temperature").Help("GPU temperature").Register(*registry_); + + prometheus::Family& CPU_temperature_ = + prometheus::BuildGauge().Name("CPU_temperature").Help("CPU temperature").Register(*registry_); }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/SystemInfo.cpp b/cpp/src/metrics/SystemInfo.cpp index 54d0c51943ec0866394ade8e74e76bd30cd0a825..154f7b0797e00426465204c7d945cb26cded68b3 100644 --- a/cpp/src/metrics/SystemInfo.cpp +++ b/cpp/src/metrics/SystemInfo.cpp @@ -15,29 +15,33 @@ // specific language governing permissions and limitations // under the License. - #include "metrics/SystemInfo.h" +#include "utils/Log.h" +#include +#include +#include +#include #include #include -#include #include -#include +#include #include #include -namespace zilliz { namespace milvus { namespace server { void SystemInfo::Init() { - if (initialized_) return; + if (initialized_) { + return; + } initialized_ = true; // initialize CPU information - FILE *file; + FILE* file; struct tms time_sample; char line[128]; last_cpu_ = times(&time_sample); @@ -45,8 +49,10 @@ SystemInfo::Init() { last_user_cpu_ = time_sample.tms_utime; file = fopen("/proc/cpuinfo", "r"); num_processors_ = 0; - while (fgets(line, 128, file) != NULL) { - if (strncmp(line, "processor", 9) == 0) num_processors_++; + while (fgets(line, 128, file) != nullptr) { + if (strncmp(line, "processor", 9) == 0) { + num_processors_++; + } if (strncmp(line, "physical", 8) == 0) { num_physical_processors_ = ParseLine(line); } @@ -54,20 +60,20 @@ SystemInfo::Init() { total_ram_ = GetPhysicalMemory(); fclose(file); - //initialize GPU information + // initialize GPU information nvmlReturn_t nvmlresult; nvmlresult = nvmlInit(); if (NVML_SUCCESS != nvmlresult) { - printf("System information initilization failed"); + SERVER_LOG_ERROR << "System information initilization failed"; return; } nvmlresult = nvmlDeviceGetCount(&num_device_); if (NVML_SUCCESS != nvmlresult) { - printf("Unable to get devidce number"); + SERVER_LOG_ERROR << "Unable to get devidce number"; return; } - //initialize network traffic information + // initialize network traffic information std::pair in_and_out_octets = Octets(); in_octets_ = in_and_out_octets.first; out_octets_ = in_and_out_octets.second; @@ -75,11 +81,13 @@ SystemInfo::Init() { } uint64_t -SystemInfo::ParseLine(char *line) { +SystemInfo::ParseLine(char* line) { // This assumes that a digit will be found and the line ends in " Kb". int i = strlen(line); - const char *p = line; - while (*p < '0' || *p > '9') p++; + const char* p = line; + while (*p < '0' || *p > '9') { + p++; + } line[i - 3] = '\0'; i = atoi(p); return static_cast(i); @@ -90,21 +98,21 @@ SystemInfo::GetPhysicalMemory() { struct sysinfo memInfo; sysinfo(&memInfo); uint64_t totalPhysMem = memInfo.totalram; - //Multiply in next statement to avoid int overflow on right hand side... + // Multiply in next statement to avoid int overflow on right hand side... totalPhysMem *= memInfo.mem_unit; return totalPhysMem; } uint64_t SystemInfo::GetProcessUsedMemory() { - //Note: this value is in KB! - FILE *file = fopen("/proc/self/status", "r"); + // Note: this value is in KB! + FILE* file = fopen("/proc/self/status", "r"); constexpr uint64_t line_length = 128; uint64_t result = -1; constexpr uint64_t KB_SIZE = 1024; char line[line_length]; - while (fgets(line, line_length, file) != NULL) { + while (fgets(line, line_length, file) != nullptr) { if (strncmp(line, "VmRSS:", 6) == 0) { result = ParseLine(line); break; @@ -117,8 +125,12 @@ SystemInfo::GetProcessUsedMemory() { double SystemInfo::MemoryPercent() { - if (!initialized_) Init(); - return (double) (GetProcessUsedMemory() * 100) / (double) total_ram_; + if (!initialized_) { + Init(); + } + + double mem_used = static_cast(GetProcessUsedMemory() * 100); + return mem_used / static_cast(total_ram_); } std::vector @@ -139,11 +151,11 @@ SystemInfo::CPUCorePercent() { } std::vector -SystemInfo::getTotalCpuTime(std::vector &work_time_array) { +SystemInfo::getTotalCpuTime(std::vector& work_time_array) { std::vector total_time_array; - FILE *file = fopen("/proc/stat", "r"); + FILE* file = fopen("/proc/stat", "r"); if (file == NULL) { - perror("Could not open stat file"); + SERVER_LOG_ERROR << "Could not open stat file"; return total_time_array; } @@ -152,16 +164,15 @@ SystemInfo::getTotalCpuTime(std::vector &work_time_array) { for (int i = 0; i < num_processors_; i++) { char buffer[1024]; - char *ret = fgets(buffer, sizeof(buffer) - 1, file); + char* ret = fgets(buffer, sizeof(buffer) - 1, file); if (ret == NULL) { - perror("Could not read stat file"); + SERVER_LOG_ERROR << "Could not read stat file"; fclose(file); return total_time_array; } - sscanf(buffer, - "cpu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu", - &user, &nice, &system, &idle, &iowait, &irq, &softirq, &steal, &guest, &guestnice); + sscanf(buffer, "cpu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu %16lu", &user, &nice, &system, &idle, + &iowait, &irq, &softirq, &steal, &guest, &guestnice); work_time_array.push_back(user + nice + system); total_time_array.push_back(user + nice + system + idle + iowait + irq + softirq + steal); @@ -173,19 +184,19 @@ SystemInfo::getTotalCpuTime(std::vector &work_time_array) { double SystemInfo::CPUPercent() { - if (!initialized_) Init(); + if (!initialized_) { + Init(); + } struct tms time_sample; clock_t now; double percent; now = times(&time_sample); - if (now <= last_cpu_ || time_sample.tms_stime < last_sys_cpu_ || - time_sample.tms_utime < last_user_cpu_) { - //Overflow detection. Just skip this value. + if (now <= last_cpu_ || time_sample.tms_stime < last_sys_cpu_ || time_sample.tms_utime < last_user_cpu_) { + // Overflow detection. Just skip this value. percent = -1.0; } else { - percent = (time_sample.tms_stime - last_sys_cpu_) + - (time_sample.tms_utime - last_user_cpu_); + percent = (time_sample.tms_stime - last_sys_cpu_) + (time_sample.tms_utime - last_user_cpu_); percent /= (now - last_cpu_); percent *= 100; } @@ -199,7 +210,8 @@ SystemInfo::CPUPercent() { std::vector SystemInfo::GPUMemoryTotal() { // get GPU usage percent - if (!initialized_) Init(); + if (!initialized_) + Init(); std::vector result; nvmlMemory_t nvmlMemory; for (int i = 0; i < num_device_; ++i) { @@ -213,7 +225,8 @@ SystemInfo::GPUMemoryTotal() { std::vector SystemInfo::GPUTemperature() { - if (!initialized_) Init(); + if (!initialized_) + Init(); std::vector result; for (int i = 0; i < num_device_; i++) { nvmlDevice_t device; @@ -228,24 +241,46 @@ SystemInfo::GPUTemperature() { std::vector SystemInfo::CPUTemperature() { std::vector result; - for (int i = 0; i <= num_physical_processors_; ++i) { - std::string path = "/sys/class/thermal/thermal_zone" + std::to_string(i) + "/temp"; - FILE *file = fopen(path.data(), "r"); - if (file == NULL) { - perror("Could not open thermal file"); - return result; + std::string path = "/sys/class/hwmon/"; + + DIR* dir = NULL; + dir = opendir(path.c_str()); + if (!dir) { + SERVER_LOG_ERROR << "Could not open hwmon directory"; + return result; + } + + struct dirent* ptr = NULL; + while ((ptr = readdir(dir)) != NULL) { + std::string filename(path); + filename.append(ptr->d_name); + + char buf[100]; + if (readlink(filename.c_str(), buf, 100) != -1) { + std::string m(buf); + if (m.find("coretemp") != std::string::npos) { + std::string object = filename; + object += "/temp1_input"; + FILE* file = fopen(object.c_str(), "r"); + if (file == nullptr) { + SERVER_LOG_ERROR << "Could not open temperature file"; + return result; + } + float temp; + fscanf(file, "%f", &temp); + result.push_back(temp / 1000); + } } - float temp; - fscanf(file, "%f", &temp); - result.push_back(temp / 1000); - fclose(file); } + closedir(dir); + return result; } std::vector SystemInfo::GPUMemoryUsed() { // get GPU memory used - if (!initialized_) Init(); + if (!initialized_) + Init(); std::vector result; nvmlMemory_t nvmlMemory; @@ -261,12 +296,12 @@ SystemInfo::GPUMemoryUsed() { std::pair SystemInfo::Octets() { pid_t pid = getpid(); -// const std::string filename = "/proc/"+std::to_string(pid)+"/net/netstat"; + // const std::string filename = "/proc/"+std::to_string(pid)+"/net/netstat"; const std::string filename = "/proc/net/netstat"; std::ifstream file(filename); std::string lastline = ""; std::string line = ""; - while (file) { + while (true) { getline(file, line); if (file.fail()) { break; @@ -293,6 +328,5 @@ SystemInfo::Octets() { return res; } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/metrics/SystemInfo.h b/cpp/src/metrics/SystemInfo.h index 802cbb0cce80956cfe71813eb2cf8499c8d2fb75..0176475232449c2fd5148197c18874af9626ce91 100644 --- a/cpp/src/metrics/SystemInfo.h +++ b/cpp/src/metrics/SystemInfo.h @@ -15,22 +15,20 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include +#include +#include +#include #include #include +#include #include -#include -#include -#include #include #include -#include #include +#include -namespace zilliz { namespace milvus { namespace server { @@ -43,72 +41,93 @@ class SystemInfo { std::chrono::system_clock::time_point net_time_ = std::chrono::system_clock::now(); int num_processors_ = 0; int num_physical_processors_ = 0; - //number of GPU + // number of GPU uint32_t num_device_ = 0; uint64_t in_octets_ = 0; uint64_t out_octets_ = 0; bool initialized_ = false; public: - static SystemInfo & + static SystemInfo& GetInstance() { static SystemInfo instance; return instance; } - void Init(); + void + Init(); - int num_processor() const { + int + num_processor() const { return num_processors_; } - int num_physical_processors() const { + int + num_physical_processors() const { return num_physical_processors_; } - uint32_t num_device() const { + uint32_t + num_device() const { return num_device_; } - uint64_t get_inoctets() { + uint64_t + get_inoctets() { return in_octets_; } - uint64_t get_octets() { + uint64_t + get_octets() { return out_octets_; } - std::chrono::system_clock::time_point get_nettime() { + std::chrono::system_clock::time_point + get_nettime() { return net_time_; } - void set_inoctets(uint64_t value) { + void + set_inoctets(uint64_t value) { in_octets_ = value; } - void set_outoctets(uint64_t value) { + void + set_outoctets(uint64_t value) { out_octets_ = value; } - void set_nettime() { + void + set_nettime() { net_time_ = std::chrono::system_clock::now(); } - uint64_t ParseLine(char *line); - uint64_t GetPhysicalMemory(); - uint64_t GetProcessUsedMemory(); - double MemoryPercent(); - double CPUPercent(); - std::pair Octets(); - std::vector GPUMemoryTotal(); - std::vector GPUMemoryUsed(); - - std::vector CPUCorePercent(); - std::vector getTotalCpuTime(std::vector &workTime); - std::vector GPUTemperature(); - std::vector CPUTemperature(); + uint64_t + ParseLine(char* line); + uint64_t + GetPhysicalMemory(); + uint64_t + GetProcessUsedMemory(); + double + MemoryPercent(); + double + CPUPercent(); + std::pair + Octets(); + std::vector + GPUMemoryTotal(); + std::vector + GPUMemoryUsed(); + + std::vector + CPUCorePercent(); + std::vector + getTotalCpuTime(std::vector& workTime); + std::vector + GPUTemperature(); + std::vector + CPUTemperature(); }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/scheduler/Algorithm.cpp b/cpp/src/scheduler/Algorithm.cpp index 10a585304a681f01e594b3985d8e8624ebbb5d9f..44f83742c25633902c5f00a249a4ec8c359d7568 100644 --- a/cpp/src/scheduler/Algorithm.cpp +++ b/cpp/src/scheduler/Algorithm.cpp @@ -15,24 +15,20 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/Algorithm.h" #include #include #include -namespace zilliz { namespace milvus { namespace scheduler { constexpr uint64_t MAXINT = std::numeric_limits::max(); uint64_t -ShortestPath(const ResourcePtr &src, - const ResourcePtr &dest, - const ResourceMgrPtr &res_mgr, - std::vector &path) { +ShortestPath(const ResourcePtr& src, const ResourcePtr& dest, const ResourceMgrPtr& res_mgr, + std::vector& path) { std::vector> paths; uint64_t num_of_resources = res_mgr->GetAllResources().size(); @@ -43,7 +39,7 @@ ShortestPath(const ResourcePtr &src, name_id_map.insert(std::make_pair(res_mgr->GetAllResources().at(i)->name(), i)); } - std::vector > dis_matrix; + std::vector> dis_matrix; dis_matrix.resize(num_of_resources); for (uint64_t i = 0; i < num_of_resources; ++i) { dis_matrix[i].resize(num_of_resources); @@ -55,11 +51,11 @@ ShortestPath(const ResourcePtr &src, std::vector vis(num_of_resources, false); std::vector dis(num_of_resources, MAXINT); - for (auto &res : res_mgr->GetAllResources()) { + for (auto& res : res_mgr->GetAllResources()) { auto cur_node = std::static_pointer_cast(res); auto cur_neighbours = cur_node->GetNeighbours(); - for (auto &neighbour : cur_neighbours) { + for (auto& neighbour : cur_neighbours) { auto neighbour_res = std::static_pointer_cast(neighbour.neighbour_node.lock()); dis_matrix[name_id_map.at(res->name())][name_id_map.at(neighbour_res->name())] = neighbour.connection.transport_cost(); @@ -107,6 +103,5 @@ ShortestPath(const ResourcePtr &src, return dis[name_id_map.at(dest->name())]; } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Algorithm.h b/cpp/src/scheduler/Algorithm.h index d7e0233ba0eb19f4c2e223e863ae53e5930a0451..69ff8f3a704fa466d5eddae74fb381a4ef0377e0 100644 --- a/cpp/src/scheduler/Algorithm.h +++ b/cpp/src/scheduler/Algorithm.h @@ -15,23 +15,18 @@ // specific language governing permissions and limitations // under the License. - -#include "resource/Resource.h" #include "ResourceMgr.h" +#include "resource/Resource.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { uint64_t -ShortestPath(const ResourcePtr &src, - const ResourcePtr &dest, - const ResourceMgrPtr &res_mgr, - std::vector &path); +ShortestPath(const ResourcePtr& src, const ResourcePtr& dest, const ResourceMgrPtr& res_mgr, + std::vector& path); -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Definition.h b/cpp/src/scheduler/Definition.h index ce41aca48d958377346f7f64c6a065fc284b129c..162988e90acca5f2682b1f880e08694395f97095 100644 --- a/cpp/src/scheduler/Definition.h +++ b/cpp/src/scheduler/Definition.h @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include -#include "db/meta/MetaTypes.h" #include "db/engine/EngineFactory.h" #include "db/engine/ExecutionEngine.h" +#include "db/meta/MetaTypes.h" -namespace zilliz { namespace milvus { namespace scheduler { @@ -42,6 +41,5 @@ using EngineFactory = engine::EngineFactory; using EngineType = engine::EngineType; using MetricType = engine::MetricType; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/JobMgr.cpp b/cpp/src/scheduler/JobMgr.cpp index 0406e98c4908fdd6bb56da9d27468a3457cfc11a..bf22f7be2ef0a5b443455f90c3c5f84295ab543d 100644 --- a/cpp/src/scheduler/JobMgr.cpp +++ b/cpp/src/scheduler/JobMgr.cpp @@ -16,17 +16,15 @@ // under the License. #include "scheduler/JobMgr.h" -#include "task/Task.h" #include "TaskCreator.h" +#include "task/Task.h" #include -namespace zilliz { namespace milvus { namespace scheduler { -JobMgr::JobMgr(ResourceMgrPtr res_mgr) - : res_mgr_(std::move(res_mgr)) { +JobMgr::JobMgr(ResourceMgrPtr res_mgr) : res_mgr_(std::move(res_mgr)) { } void @@ -47,7 +45,7 @@ JobMgr::Stop() { } void -JobMgr::Put(const JobPtr &job) { +JobMgr::Put(const JobPtr& job) { { std::lock_guard lock(mutex_); queue_.push(job); @@ -59,9 +57,7 @@ void JobMgr::worker_function() { while (running_) { std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return !queue_.empty(); - }); + cv_.wait(lock, [this] { return !queue_.empty(); }); auto job = queue_.front(); queue_.pop(); lock.unlock(); @@ -70,22 +66,19 @@ JobMgr::worker_function() { } auto tasks = build_task(job); - auto disk_list = res_mgr_->GetDiskResources(); - if (!disk_list.empty()) { - if (auto disk = disk_list[0].lock()) { - for (auto &task : tasks) { - disk->task_table().Put(task); - } + // disk resources NEVER be empty. + if (auto disk = res_mgr_->GetDiskResources()[0].lock()) { + for (auto& task : tasks) { + disk->task_table().Put(task); } } } } std::vector -JobMgr::build_task(const JobPtr &job) { +JobMgr::build_task(const JobPtr& job) { return TaskCreator::Create(job); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/JobMgr.h b/cpp/src/scheduler/JobMgr.h index 49ba9154e37958c2bdb182708ece10a48b8b6b3a..4340c9e616a6a353913cfcf8596f71c7ed969da0 100644 --- a/cpp/src/scheduler/JobMgr.h +++ b/cpp/src/scheduler/JobMgr.h @@ -16,22 +16,21 @@ // under the License. #pragma once -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include +#include "ResourceMgr.h" #include "job/Job.h" #include "task/Task.h" -#include "ResourceMgr.h" -namespace zilliz { namespace milvus { namespace scheduler { @@ -47,14 +46,14 @@ class JobMgr { public: void - Put(const JobPtr &job); + Put(const JobPtr& job); private: void worker_function(); std::vector - build_task(const JobPtr &job); + build_task(const JobPtr& job); private: bool running_ = false; @@ -70,6 +69,5 @@ class JobMgr { using JobMgrPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/ResourceFactory.cpp b/cpp/src/scheduler/ResourceFactory.cpp index de9b5bc717c6bab4cc59c3ff1bd904a62a9d0894..fad8571b61bfe693dcca11991b4ed9b9f232e10f 100644 --- a/cpp/src/scheduler/ResourceFactory.cpp +++ b/cpp/src/scheduler/ResourceFactory.cpp @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/ResourceFactory.h" -namespace zilliz { namespace milvus { namespace scheduler { std::shared_ptr -ResourceFactory::Create(const std::string &name, - const std::string &type, - uint64_t device_id, - bool enable_loader, +ResourceFactory::Create(const std::string& name, const std::string& type, uint64_t device_id, bool enable_loader, bool enable_executor) { if (type == "DISK") { return std::make_shared(name, device_id, enable_loader, enable_executor); @@ -39,6 +34,5 @@ ResourceFactory::Create(const std::string &name, } } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/ResourceFactory.h b/cpp/src/scheduler/ResourceFactory.h index f7a47ef1e565e0fa3be527c26546cc5e6885bc04..3290cb023c9fe48945a9c6db0135e2228118af87 100644 --- a/cpp/src/scheduler/ResourceFactory.h +++ b/cpp/src/scheduler/ResourceFactory.h @@ -17,28 +17,23 @@ #pragma once -#include #include +#include -#include "resource/Resource.h" #include "resource/CpuResource.h" -#include "resource/GpuResource.h" #include "resource/DiskResource.h" +#include "resource/GpuResource.h" +#include "resource/Resource.h" -namespace zilliz { namespace milvus { namespace scheduler { class ResourceFactory { public: static std::shared_ptr - Create(const std::string &name, - const std::string &type, - uint64_t device_id, - bool enable_loader = true, + Create(const std::string& name, const std::string& type, uint64_t device_id, bool enable_loader = true, bool enable_executor = true); }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/ResourceMgr.cpp b/cpp/src/scheduler/ResourceMgr.cpp index 6067b2eb01d332139fa0bf1b3281c81f17125d8b..6e839062ef116341e843e3ad37f04877bbfa9d4f 100644 --- a/cpp/src/scheduler/ResourceMgr.cpp +++ b/cpp/src/scheduler/ResourceMgr.cpp @@ -19,14 +19,19 @@ #include "scheduler/ResourceMgr.h" #include "utils/Log.h" -namespace zilliz { namespace milvus { namespace scheduler { void ResourceMgr::Start() { + if (not check_resource_valid()) { + ENGINE_LOG_ERROR << "Resources invalid, cannot start ResourceMgr."; + ENGINE_LOG_ERROR << Dump(); + return; + } + std::lock_guard lck(resources_mutex_); - for (auto &resource : resources_) { + for (auto& resource : resources_) { resource->Start(); } running_ = true; @@ -44,13 +49,13 @@ ResourceMgr::Stop() { worker_thread_.join(); std::lock_guard lck(resources_mutex_); - for (auto &resource : resources_) { + for (auto& resource : resources_) { resource->Stop(); } } ResourceWPtr -ResourceMgr::Add(ResourcePtr &&resource) { +ResourceMgr::Add(ResourcePtr&& resource) { ResourceWPtr ret(resource); std::lock_guard lck(resources_mutex_); @@ -61,8 +66,20 @@ ResourceMgr::Add(ResourcePtr &&resource) { resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1)); - if (resource->type() == ResourceType::DISK) { - disk_resources_.emplace_back(ResourceWPtr(resource)); + switch (resource->type()) { + case ResourceType::DISK: { + disk_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + case ResourceType::CPU: { + cpu_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + case ResourceType::GPU: { + gpu_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + default: { break; } } resources_.emplace_back(resource); @@ -70,13 +87,13 @@ ResourceMgr::Add(ResourcePtr &&resource) { } bool -ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connection &connection) { +ResourceMgr::Connect(const std::string& name1, const std::string& name2, Connection& connection) { auto res1 = GetResource(name1); auto res2 = GetResource(name2); if (res1 && res2) { res1->AddNeighbour(std::static_pointer_cast(res2), connection); - // TODO: enable when task balance supported -// res2->AddNeighbour(std::static_pointer_cast(res1), connection); + // TODO(wxyu): enable when task balance supported + // res2->AddNeighbour(std::static_pointer_cast(res1), connection); return true; } return false; @@ -85,14 +102,20 @@ ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connect void ResourceMgr::Clear() { std::lock_guard lck(resources_mutex_); + if (running_) { + ENGINE_LOG_ERROR << "ResourceMgr is running, cannot clear."; + return; + } disk_resources_.clear(); + cpu_resources_.clear(); + gpu_resources_.clear(); resources_.clear(); } std::vector ResourceMgr::GetComputeResources() { std::vector result; - for (auto &resource : resources_) { + for (auto& resource : resources_) { if (resource->HasExecutor()) { result.emplace_back(resource); } @@ -102,7 +125,7 @@ ResourceMgr::GetComputeResources() { ResourcePtr ResourceMgr::GetResource(ResourceType type, uint64_t device_id) { - for (auto &resource : resources_) { + for (auto& resource : resources_) { if (resource->type() == type && resource->device_id() == device_id) { return resource; } @@ -111,8 +134,8 @@ ResourceMgr::GetResource(ResourceType type, uint64_t device_id) { } ResourcePtr -ResourceMgr::GetResource(const std::string &name) { - for (auto &resource : resources_) { +ResourceMgr::GetResource(const std::string& name) { + for (auto& resource : resources_) { if (resource->name() == name) { return resource; } @@ -128,7 +151,7 @@ ResourceMgr::GetNumOfResource() const { uint64_t ResourceMgr::GetNumOfComputeResource() const { uint64_t count = 0; - for (auto &res : resources_) { + for (auto& res : resources_) { if (res->HasExecutor()) { ++count; } @@ -139,7 +162,7 @@ ResourceMgr::GetNumOfComputeResource() const { uint64_t ResourceMgr::GetNumGpuResource() const { uint64_t num = 0; - for (auto &res : resources_) { + for (auto& res : resources_) { if (res->type() == ResourceType::GPU) { num++; } @@ -149,21 +172,21 @@ ResourceMgr::GetNumGpuResource() const { std::string ResourceMgr::Dump() { - std::string str = "ResourceMgr contains " + std::to_string(resources_.size()) + " resources.\n"; + std::stringstream ss; + ss << "ResourceMgr contains " << resources_.size() << " resources." << std::endl; - for (uint64_t i = 0; i < resources_.size(); ++i) { - str += "Resource No." + std::to_string(i) + ":\n"; - //str += resources_[i]->Dump(); + for (auto& res : resources_) { + ss << res->Dump(); } - return str; + return ss.str(); } std::string ResourceMgr::DumpTaskTables() { std::stringstream ss; ss << ">>>>>>>>>>>>>>>ResourceMgr::DumpTaskTable<<<<<<<<<<<<<<<" << std::endl; - for (auto &resource : resources_) { + for (auto& resource : resources_) { ss << resource->Dump() << std::endl; ss << resource->task_table().Dump(); ss << resource->Dump() << std::endl << std::endl; @@ -171,8 +194,42 @@ ResourceMgr::DumpTaskTables() { return ss.str(); } +bool +ResourceMgr::check_resource_valid() { + { + // TODO: check one disk-resource, one cpu-resource, zero or more gpu-resource; + if (GetDiskResources().size() != 1) { + return false; + } + if (GetCpuResources().size() != 1) { + return false; + } + } + + { + // TODO: one compute-resource at least; + if (GetNumOfComputeResource() < 1) { + return false; + } + } + + { + // TODO: check disk only connect with cpu + } + + { + // TODO: check gpu only connect with cpu + } + + { + // TODO: check if exists isolated node + } + + return true; +} + void -ResourceMgr::post_event(const EventPtr &event) { +ResourceMgr::post_event(const EventPtr& event) { { std::lock_guard lock(event_mutex_); queue_.emplace(event); @@ -184,9 +241,7 @@ void ResourceMgr::event_process() { while (running_) { std::unique_lock lock(event_mutex_); - event_cv_.wait(lock, [this] { - return !queue_.empty(); - }); + event_cv_.wait(lock, [this] { return !queue_.empty(); }); auto event = queue_.front(); queue_.pop(); @@ -201,6 +256,5 @@ ResourceMgr::event_process() { } } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/ResourceMgr.h b/cpp/src/scheduler/ResourceMgr.h index d03408a7df90954307bc87f8beb8e6a5b4c53a77..7a8e1ca4ca90aa791212def599b1e411ba444cca 100644 --- a/cpp/src/scheduler/ResourceMgr.h +++ b/cpp/src/scheduler/ResourceMgr.h @@ -17,18 +17,17 @@ #pragma once -#include -#include +#include #include #include #include +#include #include -#include +#include #include "resource/Resource.h" #include "utils/Log.h" -namespace zilliz { namespace milvus { namespace scheduler { @@ -45,10 +44,10 @@ class ResourceMgr { Stop(); ResourceWPtr - Add(ResourcePtr &&resource); + Add(ResourcePtr&& resource); bool - Connect(const std::string &res1, const std::string &res2, Connection &connection); + Connect(const std::string& name1, const std::string& name2, Connection& connection); void Clear(); @@ -60,12 +59,22 @@ class ResourceMgr { public: /******** Management Interface ********/ - inline std::vector & + inline std::vector& GetDiskResources() { return disk_resources_; } - // TODO: why return shared pointer + inline std::vector& + GetCpuResources() { + return cpu_resources_; + } + + inline std::vector& + GetGpuResources() { + return gpu_resources_; + } + + // TODO(wxyu): why return shared pointer inline std::vector GetAllResources() { return resources_; @@ -78,7 +87,7 @@ class ResourceMgr { GetResource(ResourceType type, uint64_t device_id); ResourcePtr - GetResource(const std::string &name); + GetResource(const std::string& name); uint64_t GetNumOfResource() const; @@ -90,7 +99,7 @@ class ResourceMgr { GetNumGpuResource() const; public: - // TODO: add stats interface(low) + // TODO(wxyu): add stats interface(low) public: /******** Utility Functions ********/ @@ -101,8 +110,11 @@ class ResourceMgr { DumpTaskTables(); private: + bool + check_resource_valid(); + void - post_event(const EventPtr &event); + post_event(const EventPtr& event); void event_process(); @@ -111,6 +123,8 @@ class ResourceMgr { bool running_ = false; std::vector disk_resources_; + std::vector cpu_resources_; + std::vector gpu_resources_; std::vector resources_; mutable std::mutex resources_mutex_; @@ -125,6 +139,5 @@ class ResourceMgr { using ResourceMgrPtr = std::shared_ptr; using ResourceMgrWPtr = std::weak_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/SchedInst.cpp b/cpp/src/scheduler/SchedInst.cpp index 71b40de9eed2e6ca42f8f5a56f4c3e66a32925cf..b9edbca001d6fdebae98f4ac4280bf2842e67c72 100644 --- a/cpp/src/scheduler/SchedInst.cpp +++ b/cpp/src/scheduler/SchedInst.cpp @@ -15,19 +15,17 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/SchedInst.h" -#include "server/Config.h" #include "ResourceFactory.h" -#include "knowhere/index/vector_index/IndexGPUIVF.h" #include "Utils.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "server/Config.h" -#include #include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -42,7 +40,7 @@ std::mutex JobMgrInst::mutex_; void load_simple_config() { - server::Config &config = server::Config::GetInstance(); + server::Config& config = server::Config::GetInstance(); std::string mode; config.GetResourceConfigMode(mode); std::vector pool; @@ -50,7 +48,7 @@ load_simple_config() { bool cpu = false; std::set gpu_ids; - for (auto &resource : pool) { + for (auto& resource : pool) { if (resource == "cpu") { cpu = true; break; @@ -78,7 +76,7 @@ load_simple_config() { ResMgrInst::GetInstance()->Connect("disk", "cpu", io); auto pcie = Connection("pcie", 12000); - for (auto &gpu_id : gpu_ids) { + for (auto& gpu_id : gpu_ids) { ResMgrInst::GetInstance()->Add(ResourceFactory::Create(std::to_string(gpu_id), "GPU", gpu_id, true, true)); ResMgrInst::GetInstance()->Connect("cpu", std::to_string(gpu_id), io); } @@ -87,77 +85,77 @@ load_simple_config() { void load_advance_config() { -// try { -// server::ConfigNode &config = server::Config::GetInstance().GetConfig(server::CONFIG_RESOURCE); -// -// if (config.GetChildren().empty()) throw "resource_config null exception"; -// -// auto resources = config.GetChild(server::CONFIG_RESOURCES).GetChildren(); -// -// if (resources.empty()) throw "Children of resource_config null exception"; -// -// for (auto &resource : resources) { -// auto &resname = resource.first; -// auto &resconf = resource.second; -// auto type = resconf.GetValue(server::CONFIG_RESOURCE_TYPE); -//// auto memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_MEMORY); -// auto device_id = resconf.GetInt64Value(server::CONFIG_RESOURCE_DEVICE_ID); -//// auto enable_loader = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_LOADER); -// auto enable_loader = true; -// auto enable_executor = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_EXECUTOR); -// auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY); -// auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY); -// auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM); -// -// auto res = ResMgrInst::GetInstance()->Add(ResourceFactory::Create(resname, -// type, -// device_id, -// enable_loader, -// enable_executor)); -// -// if (res.lock()->type() == ResourceType::GPU) { -// auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300); -// auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300); -// auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2); -// pinned_memory = 1024 * 1024 * pinned_memory; -// temp_memory = 1024 * 1024 * temp_memory; -// knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, -// pinned_memory, -// temp_memory, -// resource_num); -// } -// } -// -// knowhere::FaissGpuResourceMgr::GetInstance().InitResource(); -// -// auto connections = config.GetChild(server::CONFIG_RESOURCE_CONNECTIONS).GetChildren(); -// if (connections.empty()) throw "connections config null exception"; -// for (auto &conn : connections) { -// auto &connect_name = conn.first; -// auto &connect_conf = conn.second; -// auto connect_speed = connect_conf.GetInt64Value(server::CONFIG_SPEED_CONNECTIONS); -// auto connect_endpoint = connect_conf.GetValue(server::CONFIG_ENDPOINT_CONNECTIONS); -// -// std::string delimiter = "==="; -// std::string left = connect_endpoint.substr(0, connect_endpoint.find(delimiter)); -// std::string right = connect_endpoint.substr(connect_endpoint.find(delimiter) + 3, -// connect_endpoint.length()); -// -// auto connection = Connection(connect_name, connect_speed); -// ResMgrInst::GetInstance()->Connect(left, right, connection); -// } -// } catch (const char *msg) { -// SERVER_LOG_ERROR << msg; -// // TODO: throw exception instead -// exit(-1); -//// throw std::exception(); -// } + // try { + // server::ConfigNode &config = server::Config::GetInstance().GetConfig(server::CONFIG_RESOURCE); + // + // if (config.GetChildren().empty()) throw "resource_config null exception"; + // + // auto resources = config.GetChild(server::CONFIG_RESOURCES).GetChildren(); + // + // if (resources.empty()) throw "Children of resource_config null exception"; + // + // for (auto &resource : resources) { + // auto &resname = resource.first; + // auto &resconf = resource.second; + // auto type = resconf.GetValue(server::CONFIG_RESOURCE_TYPE); + //// auto memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_MEMORY); + // auto device_id = resconf.GetInt64Value(server::CONFIG_RESOURCE_DEVICE_ID); + //// auto enable_loader = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_LOADER); + // auto enable_loader = true; + // auto enable_executor = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_EXECUTOR); + // auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY); + // auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY); + // auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM); + // + // auto res = ResMgrInst::GetInstance()->Add(ResourceFactory::Create(resname, + // type, + // device_id, + // enable_loader, + // enable_executor)); + // + // if (res.lock()->type() == ResourceType::GPU) { + // auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300); + // auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300); + // auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2); + // pinned_memory = 1024 * 1024 * pinned_memory; + // temp_memory = 1024 * 1024 * temp_memory; + // knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, + // pinned_memory, + // temp_memory, + // resource_num); + // } + // } + // + // knowhere::FaissGpuResourceMgr::GetInstance().InitResource(); + // + // auto connections = config.GetChild(server::CONFIG_RESOURCE_CONNECTIONS).GetChildren(); + // if (connections.empty()) throw "connections config null exception"; + // for (auto &conn : connections) { + // auto &connect_name = conn.first; + // auto &connect_conf = conn.second; + // auto connect_speed = connect_conf.GetInt64Value(server::CONFIG_SPEED_CONNECTIONS); + // auto connect_endpoint = connect_conf.GetValue(server::CONFIG_ENDPOINT_CONNECTIONS); + // + // std::string delimiter = "==="; + // std::string left = connect_endpoint.substr(0, connect_endpoint.find(delimiter)); + // std::string right = connect_endpoint.substr(connect_endpoint.find(delimiter) + 3, + // connect_endpoint.length()); + // + // auto connection = Connection(connect_name, connect_speed); + // ResMgrInst::GetInstance()->Connect(left, right, connection); + // } + // } catch (const char *msg) { + // SERVER_LOG_ERROR << msg; + // // TODO(wxyu): throw exception instead + // exit(-1); + //// throw std::exception(); + // } } void StartSchedulerService() { load_simple_config(); -// load_advance_config(); + // load_advance_config(); ResMgrInst::GetInstance()->Start(); SchedInst::GetInstance()->Start(); JobMgrInst::GetInstance()->Start(); @@ -170,6 +168,5 @@ StopSchedulerService() { ResMgrInst::GetInstance()->Stop(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/SchedInst.h b/cpp/src/scheduler/SchedInst.h index 4cca6ec5a957f28a589a3f715120a93177669c6d..dc8e4ed478fb84ac964c5e4b7315c2df1d01a6da 100644 --- a/cpp/src/scheduler/SchedInst.h +++ b/cpp/src/scheduler/SchedInst.h @@ -17,14 +17,13 @@ #pragma once +#include "JobMgr.h" #include "ResourceMgr.h" #include "Scheduler.h" -#include "JobMgr.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -88,6 +87,5 @@ StartSchedulerService(); void StopSchedulerService(); -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Scheduler.cpp b/cpp/src/scheduler/Scheduler.cpp index 24f7bfe73b3e62f2a7b53b8fe3a2f96f6aa60da5..3a82a1b361f1a64870e8922bc1bd4224a8588cec 100644 --- a/cpp/src/scheduler/Scheduler.cpp +++ b/cpp/src/scheduler/Scheduler.cpp @@ -16,20 +16,17 @@ // under the License. #include "scheduler/Scheduler.h" +#include "Algorithm.h" +#include "action/Action.h" #include "cache/GpuCacheMgr.h" #include "event/LoadCompletedEvent.h" -#include "action/Action.h" -#include "Algorithm.h" #include -namespace zilliz { namespace milvus { namespace scheduler { -Scheduler::Scheduler(ResourceMgrWPtr res_mgr) - : running_(false), - res_mgr_(std::move(res_mgr)) { +Scheduler::Scheduler(ResourceMgrWPtr res_mgr) : running_(false), res_mgr_(std::move(res_mgr)) { if (auto mgr = res_mgr_.lock()) { mgr->RegisterSubscriber(std::bind(&Scheduler::PostEvent, this, std::placeholders::_1)); } @@ -61,7 +58,7 @@ Scheduler::Stop() { } void -Scheduler::PostEvent(const EventPtr &event) { +Scheduler::PostEvent(const EventPtr& event) { { std::lock_guard lock(event_mutex_); event_queue_.push(event); @@ -78,9 +75,7 @@ void Scheduler::worker_function() { while (running_) { std::unique_lock lock(event_mutex_); - event_cv_.wait(lock, [this] { - return !event_queue_.empty(); - }); + event_cv_.wait(lock, [this] { return !event_queue_.empty(); }); auto event = event_queue_.front(); event_queue_.pop(); if (event == nullptr) { @@ -92,14 +87,14 @@ Scheduler::worker_function() { } void -Scheduler::Process(const EventPtr &event) { +Scheduler::Process(const EventPtr& event) { auto process_event = event_register_.at(static_cast(event->Type())); process_event(event); } -// TODO: refactor the function +// TODO(wxyu): refactor the function void -Scheduler::OnLoadCompleted(const EventPtr &event) { +Scheduler::OnLoadCompleted(const EventPtr& event) { auto load_completed_event = std::static_pointer_cast(event); if (auto resource = event->resource_.lock()) { resource->WakeupExecutor(); @@ -118,31 +113,28 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { Action::PushTaskToAllNeighbour(load_completed_event->task_table_item_->task, resource); break; } - default: { - break; - } + default: { break; } } } } void -Scheduler::OnStartUp(const EventPtr &event) { +Scheduler::OnStartUp(const EventPtr& event) { if (auto resource = event->resource_.lock()) { resource->WakeupLoader(); } } void -Scheduler::OnFinishTask(const EventPtr &event) { +Scheduler::OnFinishTask(const EventPtr& event) { } void -Scheduler::OnTaskTableUpdated(const EventPtr &event) { +Scheduler::OnTaskTableUpdated(const EventPtr& event) { if (auto resource = event->resource_.lock()) { resource->WakeupLoader(); } } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Scheduler.h b/cpp/src/scheduler/Scheduler.h index 073f7eeb0d8d9617a01cf54ca82e7ba98797e420..5b222cc41a7f5ac792173ca5095d3f5d7df107ee 100644 --- a/cpp/src/scheduler/Scheduler.h +++ b/cpp/src/scheduler/Scheduler.h @@ -18,27 +18,26 @@ #pragma once #include -#include #include -#include #include +#include +#include #include -#include "resource/Resource.h" #include "ResourceMgr.h" +#include "resource/Resource.h" #include "utils/Log.h" -namespace zilliz { namespace milvus { namespace scheduler { -// TODO: refactor, not friendly to unittest, logical in framework code +// TODO(wxyu): refactor, not friendly to unittest, logical in framework code class Scheduler { public: explicit Scheduler(ResourceMgrWPtr res_mgr); - Scheduler(const Scheduler &) = delete; - Scheduler(Scheduler &&) = delete; + Scheduler(const Scheduler&) = delete; + Scheduler(Scheduler&&) = delete; /* * Start worker thread; @@ -56,7 +55,7 @@ class Scheduler { * Post event to scheduler event queue; */ void - PostEvent(const EventPtr &event); + PostEvent(const EventPtr& event); /* * Dump as string; @@ -74,7 +73,7 @@ class Scheduler { * Pull task from neighbours; */ void - OnStartUp(const EventPtr &event); + OnStartUp(const EventPtr& event); /* * Process finish task events; @@ -83,7 +82,7 @@ class Scheduler { * Pull task from neighbours; */ void - OnFinishTask(const EventPtr &event); + OnFinishTask(const EventPtr& event); /* * Process copy completed events; @@ -93,7 +92,7 @@ class Scheduler { * Pull task from neighbours; */ void - OnLoadCompleted(const EventPtr &event); + OnLoadCompleted(const EventPtr& event); /* * Process task table updated events, which happened on task_table->put; @@ -102,14 +101,14 @@ class Scheduler { * Push task to neighbours; */ void - OnTaskTableUpdated(const EventPtr &event); + OnTaskTableUpdated(const EventPtr& event); private: /* * Dispatch event to event handler; */ void - Process(const EventPtr &event); + Process(const EventPtr& event); /* * Called by worker_thread_; @@ -131,6 +130,5 @@ class Scheduler { using SchedulerPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/TaskCreator.cpp b/cpp/src/scheduler/TaskCreator.cpp index 8b0378c646299924a6fca2e10b0d82472d61bf83..ee63c2c6b7763ea102ec4650aa5590965ca0ed74 100644 --- a/cpp/src/scheduler/TaskCreator.cpp +++ b/cpp/src/scheduler/TaskCreator.cpp @@ -16,15 +16,16 @@ // under the License. #include "scheduler/TaskCreator.h" +#include +#include "SchedInst.h" #include "scheduler/tasklabel/BroadcastLabel.h" #include "tasklabel/DefaultLabel.h" -namespace zilliz { namespace milvus { namespace scheduler { std::vector -TaskCreator::Create(const JobPtr &job) { +TaskCreator::Create(const JobPtr& job) { switch (job->type()) { case JobType::SEARCH: { return Create(std::static_pointer_cast(job)); @@ -32,19 +33,22 @@ TaskCreator::Create(const JobPtr &job) { case JobType::DELETE: { return Create(std::static_pointer_cast(job)); } + case JobType::BUILD: { + return Create(std::static_pointer_cast(job)); + } default: { - // TODO: error + // TODO(wxyu): error return std::vector(); } } } std::vector -TaskCreator::Create(const SearchJobPtr &job) { +TaskCreator::Create(const SearchJobPtr& job) { std::vector tasks; - for (auto &index_file : job->index_files()) { - auto task = std::make_shared(index_file.second); - task->label() = std::make_shared(); + for (auto& index_file : job->index_files()) { + auto label = std::make_shared(); + auto task = std::make_shared(index_file.second, label); task->job_ = job; tasks.emplace_back(task); } @@ -53,16 +57,30 @@ TaskCreator::Create(const SearchJobPtr &job) { } std::vector -TaskCreator::Create(const DeleteJobPtr &job) { +TaskCreator::Create(const DeleteJobPtr& job) { std::vector tasks; - auto task = std::make_shared(job); - task->label() = std::make_shared(); + auto label = std::make_shared(); + auto task = std::make_shared(job, label); task->job_ = job; tasks.emplace_back(task); return tasks; } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +std::vector +TaskCreator::Create(const BuildIndexJobPtr& job) { + std::vector tasks; + // TODO(yukun): remove "disk" hardcode here + ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource("disk"); + + for (auto& to_index_file : job->to_index_files()) { + auto label = std::make_shared(std::weak_ptr(res_ptr)); + auto task = std::make_shared(to_index_file.second, label); + task->job_ = job; + tasks.emplace_back(task); + } + return tasks; +} + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/TaskCreator.h b/cpp/src/scheduler/TaskCreator.h index 81cb25010f307c7424375a2340e4b8debc7cfa1b..ef71d9a3d37b90f3d6f4cbc4e8eb608b4e812145 100644 --- a/cpp/src/scheduler/TaskCreator.h +++ b/cpp/src/scheduler/TaskCreator.h @@ -16,41 +16,43 @@ // under the License. #pragma once -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include +#include "job/DeleteJob.h" #include "job/Job.h" #include "job/SearchJob.h" -#include "job/DeleteJob.h" -#include "task/Task.h" -#include "task/SearchTask.h" +#include "task/BuildIndexTask.h" #include "task/DeleteTask.h" +#include "task/SearchTask.h" +#include "task/Task.h" -namespace zilliz { namespace milvus { namespace scheduler { class TaskCreator { public: static std::vector - Create(const JobPtr &job); + Create(const JobPtr& job); public: static std::vector - Create(const SearchJobPtr &job); + Create(const SearchJobPtr& job); + + static std::vector + Create(const DeleteJobPtr& job); static std::vector - Create(const DeleteJobPtr &job); + Create(const BuildIndexJobPtr& job); }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/TaskTable.cpp b/cpp/src/scheduler/TaskTable.cpp index a7343ee5093c38e443a6c742c35bfb3a52845fb1..0d6742c649f19dbb3c25349bf089d80825970aab 100644 --- a/cpp/src/scheduler/TaskTable.cpp +++ b/cpp/src/scheduler/TaskTable.cpp @@ -15,36 +15,43 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/TaskTable.h" -#include "event/TaskTableUpdatedEvent.h" #include "Utils.h" +#include "event/TaskTableUpdatedEvent.h" -#include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace scheduler { std::string ToString(TaskTableItemState state) { switch (state) { - case TaskTableItemState::INVALID: return "INVALID"; - case TaskTableItemState::START: return "START"; - case TaskTableItemState::LOADING: return "LOADING"; - case TaskTableItemState::LOADED: return "LOADED"; - case TaskTableItemState::EXECUTING: return "EXECUTING"; - case TaskTableItemState::EXECUTED: return "EXECUTED"; - case TaskTableItemState::MOVING: return "MOVING"; - case TaskTableItemState::MOVED: return "MOVED"; - default: return ""; + case TaskTableItemState::INVALID: + return "INVALID"; + case TaskTableItemState::START: + return "START"; + case TaskTableItemState::LOADING: + return "LOADING"; + case TaskTableItemState::LOADED: + return "LOADED"; + case TaskTableItemState::EXECUTING: + return "EXECUTING"; + case TaskTableItemState::EXECUTED: + return "EXECUTED"; + case TaskTableItemState::MOVING: + return "MOVING"; + case TaskTableItemState::MOVED: + return "MOVED"; + default: + return ""; } } std::string -ToString(const TaskTimestamp ×tamp) { +ToString(const TaskTimestamp& timestamp) { std::stringstream ss; ss << " subscriber) { @@ -117,7 +116,7 @@ class TaskTable { * Called by DBImpl; */ void - Put(std::vector &tasks); + Put(std::vector& tasks); /* * Return task table item reference; @@ -130,8 +129,8 @@ class TaskTable { * Remove sequence task which is DONE or MOVED from front; * Called by ? */ -// void -// Clear(); + // void + // Clear(); /* * Return true if task table empty, otherwise false; @@ -150,16 +149,17 @@ class TaskTable { } public: - TaskTableItemPtr & - operator[](uint64_t index) { + TaskTableItemPtr& operator[](uint64_t index) { return table_[index]; } - std::deque::iterator begin() { + std::deque::iterator + begin() { return table_.begin(); } - std::deque::iterator end() { + std::deque::iterator + end() { return table_.end(); } @@ -173,7 +173,7 @@ class TaskTable { public: /******** Action ********/ - // TODO: bool to Status + // TODO(wxyu): bool to Status /* * Load a task; * Set state loading; @@ -254,6 +254,5 @@ class TaskTable { uint64_t last_finish_ = -1; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Utils.cpp b/cpp/src/scheduler/Utils.cpp index bb19950ffece8f6de4ad15e478e9ad7946989979..071f152227205f141079c6bb455b269358bf642c 100644 --- a/cpp/src/scheduler/Utils.cpp +++ b/cpp/src/scheduler/Utils.cpp @@ -15,13 +15,11 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/Utils.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -40,6 +38,5 @@ get_num_gpu() { return n_devices; } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/Utils.h b/cpp/src/scheduler/Utils.h index c69028f0fa64e2f0d57d1f82291ebe3c3523d587..e999e0fda3d460ac95299720c104d72385f49896 100644 --- a/cpp/src/scheduler/Utils.h +++ b/cpp/src/scheduler/Utils.h @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. - #include -namespace zilliz { namespace milvus { namespace scheduler { @@ -28,6 +26,5 @@ get_current_timestamp(); uint64_t get_num_gpu(); -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/action/Action.h b/cpp/src/scheduler/action/Action.h index a5f67aa98d05cbb8a3a9c804f0e395e82d930c3e..51c788f82f9d4ab91f8e0ffd874bede5efedd4fe 100644 --- a/cpp/src/scheduler/action/Action.h +++ b/cpp/src/scheduler/action/Action.h @@ -17,35 +17,32 @@ #pragma once -#include "scheduler/resource/Resource.h" #include "scheduler/ResourceMgr.h" +#include "scheduler/resource/Resource.h" #include -namespace zilliz { namespace milvus { namespace scheduler { class Action { public: static void - PushTaskToNeighbourRandomly(const TaskPtr &task, const ResourcePtr &self); + PushTaskToNeighbourRandomly(const TaskPtr& task, const ResourcePtr& self); static void - PushTaskToAllNeighbour(const TaskPtr &task, const ResourcePtr &self); + PushTaskToAllNeighbour(const TaskPtr& task, const ResourcePtr& self); static void - PushTaskToResource(const TaskPtr &task, const ResourcePtr &dest); + PushTaskToResource(const TaskPtr& task, const ResourcePtr& dest); static void DefaultLabelTaskScheduler(ResourceMgrWPtr res_mgr, ResourcePtr resource, std::shared_ptr event); static void - SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, - ResourcePtr resource, + SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, ResourcePtr resource, std::shared_ptr event); }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp index 909112eb627c5a7c4d6e7d2b6f73a63656c95ae0..53dd45faca0f5c57fc6f1973128b4b1b0a978c1f 100644 --- a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp +++ b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp @@ -15,26 +15,26 @@ // specific language governing permissions and limitations // under the License. - #include #include #include "../Algorithm.h" -#include "src/cache/GpuCacheMgr.h" #include "Action.h" +#include "src/cache/GpuCacheMgr.h" +#include "src/server/Config.h" -namespace zilliz { namespace milvus { namespace scheduler { std::vector -get_neighbours(const ResourcePtr &self) { +get_neighbours(const ResourcePtr& self) { std::vector neighbours; - for (auto &neighbour_node : self->GetNeighbours()) { + for (auto& neighbour_node : self->GetNeighbours()) { auto node = neighbour_node.neighbour_node.lock(); - if (not node) continue; + if (not node) + continue; auto resource = std::static_pointer_cast(node); -// if (not resource->HasExecutor()) continue; + // if (not resource->HasExecutor()) continue; neighbours.emplace_back(resource); } @@ -42,14 +42,15 @@ get_neighbours(const ResourcePtr &self) { } std::vector> -get_neighbours_with_connetion(const ResourcePtr &self) { +get_neighbours_with_connetion(const ResourcePtr& self) { std::vector> neighbours; - for (auto &neighbour_node : self->GetNeighbours()) { + for (auto& neighbour_node : self->GetNeighbours()) { auto node = neighbour_node.neighbour_node.lock(); - if (not node) continue; + if (not node) + continue; auto resource = std::static_pointer_cast(node); -// if (not resource->HasExecutor()) continue; + // if (not resource->HasExecutor()) continue; Connection conn = neighbour_node.connection; neighbours.emplace_back(std::make_pair(resource, conn)); } @@ -57,13 +58,12 @@ get_neighbours_with_connetion(const ResourcePtr &self) { } void -Action::PushTaskToNeighbourRandomly(const TaskPtr &task, - const ResourcePtr &self) { +Action::PushTaskToNeighbourRandomly(const TaskPtr& task, const ResourcePtr& self) { auto neighbours = get_neighbours_with_connetion(self); if (not neighbours.empty()) { std::vector speeds; uint64_t total_speed = 0; - for (auto &neighbour : neighbours) { + for (auto& neighbour : neighbours) { uint64_t speed = neighbour.second.speed(); speeds.emplace_back(speed); total_speed += speed; @@ -83,38 +83,37 @@ Action::PushTaskToNeighbourRandomly(const TaskPtr &task, } } else { - //TODO: process + // TODO(wxyu): process } } void -Action::PushTaskToAllNeighbour(const TaskPtr &task, const ResourcePtr &self) { +Action::PushTaskToAllNeighbour(const TaskPtr& task, const ResourcePtr& self) { auto neighbours = get_neighbours(self); - for (auto &neighbour : neighbours) { + for (auto& neighbour : neighbours) { neighbour->task_table().Put(task); } } void -Action::PushTaskToResource(const TaskPtr &task, const ResourcePtr &dest) { +Action::PushTaskToResource(const TaskPtr& task, const ResourcePtr& dest) { dest->task_table().Put(task); } void -Action::DefaultLabelTaskScheduler(ResourceMgrWPtr res_mgr, - ResourcePtr resource, +Action::DefaultLabelTaskScheduler(ResourceMgrWPtr res_mgr, ResourcePtr resource, std::shared_ptr event) { if (not resource->HasExecutor() && event->task_table_item_->Move()) { auto task = event->task_table_item_->task; auto search_task = std::static_pointer_cast(task); bool moved = false; - //to support test task, REFACTOR + // to support test task, REFACTOR if (auto index_engine = search_task->index_engine_) { auto location = index_engine->GetLocation(); for (auto i = 0; i < res_mgr.lock()->GetNumGpuResource(); ++i) { - auto index = zilliz::milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location); + auto index = milvus::cache::GpuCacheMgr::GetInstance(i)->GetIndex(location); if (index != nullptr) { moved = true; auto dest_resource = res_mgr.lock()->GetResource(ResourceType::GPU, i); @@ -131,8 +130,7 @@ Action::DefaultLabelTaskScheduler(ResourceMgrWPtr res_mgr, } void -Action::SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, - ResourcePtr resource, +Action::SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, ResourcePtr resource, std::shared_ptr event) { auto task = event->task_table_item_->task; if (resource->type() == ResourceType::DISK) { @@ -140,32 +138,55 @@ Action::SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, auto compute_resources = res_mgr.lock()->GetComputeResources(); std::vector> paths; std::vector transport_costs; - for (auto &res : compute_resources) { + for (auto& res : compute_resources) { std::vector path; uint64_t transport_cost = ShortestPath(resource, res, res_mgr.lock(), path); transport_costs.push_back(transport_cost); paths.emplace_back(path); } + if (task->job_.lock()->type() == JobType::SEARCH) { + // step 2: select min cost, cost(resource) = avg_cost * task_to_do + transport_cost + uint64_t min_cost = std::numeric_limits::max(); + uint64_t min_cost_idx = 0; + for (uint64_t i = 0; i < compute_resources.size(); ++i) { + if (compute_resources[i]->TotalTasks() == 0) { + min_cost_idx = i; + break; + } + uint64_t cost = + compute_resources[i]->TaskAvgCost() * compute_resources[i]->NumOfTaskToExec() + transport_costs[i]; + if (min_cost > cost) { + min_cost = cost; + min_cost_idx = i; + } + } - // step 2: select min cost, cost(resource) = avg_cost * task_to_do + transport_cost - uint64_t min_cost = std::numeric_limits::max(); - uint64_t min_cost_idx = 0; - for (uint64_t i = 0; i < compute_resources.size(); ++i) { - if (compute_resources[i]->TotalTasks() == 0) { - min_cost_idx = i; - break; + // step 3: set path in task + Path task_path(paths[min_cost_idx], paths[min_cost_idx].size() - 1); + task->path() = task_path; + } else if (task->job_.lock()->type() == JobType::BUILD) { + // step2: Read device id in config + // get build index gpu resource + server::Config& config = server::Config::GetInstance(); + int32_t build_index_gpu; + Status stat = config.GetDBConfigBuildIndexGPU(build_index_gpu); + + bool find_gpu_res = false; + for (uint64_t i = 0; i < compute_resources.size(); ++i) { + if (res_mgr.lock()->GetResource(ResourceType::GPU, build_index_gpu) != nullptr) { + if (compute_resources[i]->name() == + res_mgr.lock()->GetResource(ResourceType::GPU, build_index_gpu)->name()) { + find_gpu_res = true; + Path task_path(paths[i], paths[i].size() - 1); + task->path() = task_path; + break; + } + } } - uint64_t cost = compute_resources[i]->TaskAvgCost() * compute_resources[i]->NumOfTaskToExec() - + transport_costs[i]; - if (min_cost > cost) { - min_cost = cost; - min_cost_idx = i; + if (not find_gpu_res) { + task->path() = Path(paths[0], paths[0].size() - 1); } } - - // step 3: set path in task - Path task_path(paths[min_cost_idx], paths[min_cost_idx].size() - 1); - task->path() = task_path; } if (resource->name() == task->path().Last()) { @@ -178,6 +199,5 @@ Action::SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, } } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/Event.h b/cpp/src/scheduler/event/Event.h index 860c60c5b727b600972e73f86d32dc8bbace3ce9..5b1f37fb9962c9a86d4fbce16627c5d58211bdc8 100644 --- a/cpp/src/scheduler/event/Event.h +++ b/cpp/src/scheduler/event/Event.h @@ -18,28 +18,19 @@ #pragma once #include -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { -enum class EventType { - START_UP, - LOAD_COMPLETED, - FINISH_TASK, - TASK_TABLE_UPDATED -}; +enum class EventType { START_UP, LOAD_COMPLETED, FINISH_TASK, TASK_TABLE_UPDATED }; class Resource; class Event { public: - explicit - Event(EventType type, std::weak_ptr resource) - : type_(type), - resource_(std::move(resource)) { + explicit Event(EventType type, std::weak_ptr resource) : type_(type), resource_(std::move(resource)) { } inline EventType @@ -50,7 +41,8 @@ class Event { virtual std::string Dump() const = 0; - friend std::ostream &operator<<(std::ostream &out, const Event &event); + friend std::ostream& + operator<<(std::ostream& out, const Event& event); public: EventType type_; @@ -59,6 +51,5 @@ class Event { using EventPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/EventDump.cpp b/cpp/src/scheduler/event/EventDump.cpp index a9ed751d88486bcf95950869be9a67eaf197de99..91e2da369aaf799084e92a390579e18fc6bba622 100644 --- a/cpp/src/scheduler/event/EventDump.cpp +++ b/cpp/src/scheduler/event/EventDump.cpp @@ -15,47 +15,44 @@ // specific language governing permissions and limitations // under the License. - #include "Event.h" -#include "StartUpEvent.h" -#include "LoadCompletedEvent.h" #include "FinishTaskEvent.h" +#include "LoadCompletedEvent.h" +#include "StartUpEvent.h" #include "TaskTableUpdatedEvent.h" -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const Event &event) { +std::ostream& +operator<<(std::ostream& out, const Event& event) { out << event.Dump(); return out; } -std::ostream & -operator<<(std::ostream &out, const StartUpEvent &event) { +std::ostream& +operator<<(std::ostream& out, const StartUpEvent& event) { out << event.Dump(); return out; } -std::ostream & -operator<<(std::ostream &out, const LoadCompletedEvent &event) { +std::ostream& +operator<<(std::ostream& out, const LoadCompletedEvent& event) { out << event.Dump(); return out; } -std::ostream & -operator<<(std::ostream &out, const FinishTaskEvent &event) { +std::ostream& +operator<<(std::ostream& out, const FinishTaskEvent& event) { out << event.Dump(); return out; } -std::ostream & -operator<<(std::ostream &out, const TaskTableUpdatedEvent &event) { +std::ostream& +operator<<(std::ostream& out, const TaskTableUpdatedEvent& event) { out << event.Dump(); return out; } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/FinishTaskEvent.h b/cpp/src/scheduler/event/FinishTaskEvent.h index f49acb16ad4955a1c06aeecad5c3985fe8c65ec3..1b2d8f9818baa9c7fcf0eff18ca7bee149327b9b 100644 --- a/cpp/src/scheduler/event/FinishTaskEvent.h +++ b/cpp/src/scheduler/event/FinishTaskEvent.h @@ -17,21 +17,20 @@ #pragma once +#include "scheduler/TaskTable.h" #include "scheduler/event/Event.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { class FinishTaskEvent : public Event { public: FinishTaskEvent(std::weak_ptr resource, TaskTableItemPtr task_table_item) - : Event(EventType::FINISH_TASK, std::move(resource)), - task_table_item_(std::move(task_table_item)) { + : Event(EventType::FINISH_TASK, std::move(resource)), task_table_item_(std::move(task_table_item)) { } inline std::string @@ -39,12 +38,12 @@ class FinishTaskEvent : public Event { return ""; } - friend std::ostream &operator<<(std::ostream &out, const FinishTaskEvent &event); + friend std::ostream& + operator<<(std::ostream& out, const FinishTaskEvent& event); public: TaskTableItemPtr task_table_item_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/LoadCompletedEvent.h b/cpp/src/scheduler/event/LoadCompletedEvent.h index 8d727f7435984b3573d44fcab44fbbe21e964504..5a701e0dfc6cd005c8594ba3f425c2f2085161b9 100644 --- a/cpp/src/scheduler/event/LoadCompletedEvent.h +++ b/cpp/src/scheduler/event/LoadCompletedEvent.h @@ -17,22 +17,20 @@ #pragma once -#include "scheduler/event/Event.h" #include "scheduler/TaskTable.h" +#include "scheduler/event/Event.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { class LoadCompletedEvent : public Event { public: LoadCompletedEvent(std::weak_ptr resource, TaskTableItemPtr task_table_item) - : Event(EventType::LOAD_COMPLETED, std::move(resource)), - task_table_item_(std::move(task_table_item)) { + : Event(EventType::LOAD_COMPLETED, std::move(resource)), task_table_item_(std::move(task_table_item)) { } inline std::string @@ -40,12 +38,12 @@ class LoadCompletedEvent : public Event { return ""; } - friend std::ostream &operator<<(std::ostream &out, const LoadCompletedEvent &event); + friend std::ostream& + operator<<(std::ostream& out, const LoadCompletedEvent& event); public: TaskTableItemPtr task_table_item_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/StartUpEvent.h b/cpp/src/scheduler/event/StartUpEvent.h index 8e4ad120dee2f5c297e190cce5209c2e71d1e692..c4abb4e27ca85f7cda3fc12423cb9813d8891770 100644 --- a/cpp/src/scheduler/event/StartUpEvent.h +++ b/cpp/src/scheduler/event/StartUpEvent.h @@ -20,17 +20,15 @@ #include "scheduler/event/Event.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { class StartUpEvent : public Event { public: - explicit StartUpEvent(std::weak_ptr resource) - : Event(EventType::START_UP, std::move(resource)) { + explicit StartUpEvent(std::weak_ptr resource) : Event(EventType::START_UP, std::move(resource)) { } inline std::string @@ -38,9 +36,9 @@ class StartUpEvent : public Event { return ""; } - friend std::ostream &operator<<(std::ostream &out, const StartUpEvent &event); + friend std::ostream& + operator<<(std::ostream& out, const StartUpEvent& event); }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/event/TaskTableUpdatedEvent.h b/cpp/src/scheduler/event/TaskTableUpdatedEvent.h index ec579b31bef4214d75d79bafe2f66ac81b246497..ed64a42d899bcf97edeb0c7d065c8eb7a341a1a9 100644 --- a/cpp/src/scheduler/event/TaskTableUpdatedEvent.h +++ b/cpp/src/scheduler/event/TaskTableUpdatedEvent.h @@ -20,10 +20,9 @@ #include "Event.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -38,9 +37,9 @@ class TaskTableUpdatedEvent : public Event { return ""; } - friend std::ostream &operator<<(std::ostream &out, const TaskTableUpdatedEvent &event); + friend std::ostream& + operator<<(std::ostream& out, const TaskTableUpdatedEvent& event); }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/BuildIndexJob.cpp b/cpp/src/scheduler/job/BuildIndexJob.cpp new file mode 100644 index 0000000000000000000000000000000000000000..423121c5fb130654be90c4bb5713a29bbd576451 --- /dev/null +++ b/cpp/src/scheduler/job/BuildIndexJob.cpp @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "scheduler/job/BuildIndexJob.h" +#include "utils/Log.h" + +#include + +namespace milvus { +namespace scheduler { + +BuildIndexJob::BuildIndexJob(JobId id, engine::meta::MetaPtr meta_ptr, engine::DBOptions options) + : Job(id, JobType::BUILD), meta_ptr_(std::move(meta_ptr)), options_(std::move(options)) { +} + +bool +BuildIndexJob::AddToIndexFiles(const engine::meta::TableFileSchemaPtr& to_index_file) { + std::unique_lock lock(mutex_); + if (to_index_file == nullptr || to_index_files_.find(to_index_file->id_) != to_index_files_.end()) { + return false; + } + + SERVER_LOG_DEBUG << "BuildIndexJob " << id() << " add to_index file: " << to_index_file->id_; + + to_index_files_[to_index_file->id_] = to_index_file; +} + +Status& +BuildIndexJob::WaitBuildIndexFinish() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return to_index_files_.empty(); }); + SERVER_LOG_DEBUG << "BuildIndexJob " << id() << " all done"; +} + +void +BuildIndexJob::BuildIndexDone(size_t to_index_id) { + std::unique_lock lock(mutex_); + to_index_files_.erase(to_index_id); + cv_.notify_all(); + SERVER_LOG_DEBUG << "BuildIndexJob " << id() << " finish index file: " << to_index_id; +} + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/BuildIndexJob.h b/cpp/src/scheduler/job/BuildIndexJob.h new file mode 100644 index 0000000000000000000000000000000000000000..b6ca462537d787b691705c4c89ba8762f98030c7 --- /dev/null +++ b/cpp/src/scheduler/job/BuildIndexJob.h @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Job.h" +#include "db/meta/Meta.h" +#include "scheduler/Definition.h" + +namespace milvus { +namespace scheduler { + +using engine::meta::TableFileSchemaPtr; + +using Id2ToIndexMap = std::unordered_map; +using Id2ToTableFileMap = std::unordered_map; + +class BuildIndexJob : public Job { + public: + explicit BuildIndexJob(JobId id, engine::meta::MetaPtr meta_ptr, engine::DBOptions options); + + public: + bool + AddToIndexFiles(const TableFileSchemaPtr& to_index_file); + + Status& + WaitBuildIndexFinish(); + + void + BuildIndexDone(size_t to_index_id); + + public: + Status& + GetStatus() { + return status_; + } + + Id2ToIndexMap& + to_index_files() { + return to_index_files_; + } + + engine::meta::MetaPtr + meta() const { + return meta_ptr_; + } + + engine::DBOptions + options() const { + return options_; + } + + private: + Id2ToIndexMap to_index_files_; + engine::meta::MetaPtr meta_ptr_; + engine::DBOptions options_; + + Status status_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +using BuildIndexJobPtr = std::shared_ptr; + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/DeleteJob.cpp b/cpp/src/scheduler/job/DeleteJob.cpp index 9d917751c69616f4ffec2e4b8b2a0ed4cecaff7c..96a6bb18173eb1cd318d09ad5f9873ce637e474f 100644 --- a/cpp/src/scheduler/job/DeleteJob.cpp +++ b/cpp/src/scheduler/job/DeleteJob.cpp @@ -19,14 +19,10 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { -DeleteJob::DeleteJob(JobId id, - std::string table_id, - engine::meta::MetaPtr meta_ptr, - uint64_t num_resource) +DeleteJob::DeleteJob(JobId id, std::string table_id, engine::meta::MetaPtr meta_ptr, uint64_t num_resource) : Job(id, JobType::DELETE), table_id_(std::move(table_id)), meta_ptr_(std::move(meta_ptr)), @@ -36,9 +32,7 @@ DeleteJob::DeleteJob(JobId id, void DeleteJob::WaitAndDelete() { std::unique_lock lock(mutex_); - cv_.wait(lock, [&] { - return done_resource == num_resource_; - }); + cv_.wait(lock, [&] { return done_resource == num_resource_; }); meta_ptr_->DeleteTableFiles(table_id_); } @@ -51,6 +45,5 @@ DeleteJob::ResourceDone() { cv_.notify_one(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/DeleteJob.h b/cpp/src/scheduler/job/DeleteJob.h index 7d8b20e47c5df683ed3828b3720637ae34c20020..4ac48f6913cb1aa3e7cc476fd828cc01fc3a4a96 100644 --- a/cpp/src/scheduler/job/DeleteJob.h +++ b/cpp/src/scheduler/job/DeleteJob.h @@ -16,30 +16,26 @@ // under the License. #pragma once -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include #include "Job.h" #include "db/meta/Meta.h" -namespace zilliz { namespace milvus { namespace scheduler { class DeleteJob : public Job { public: - DeleteJob(JobId id, - std::string table_id, - engine::meta::MetaPtr meta_ptr, - uint64_t num_resource); + DeleteJob(JobId id, std::string table_id, engine::meta::MetaPtr meta_ptr, uint64_t num_resource); public: void @@ -71,6 +67,5 @@ class DeleteJob : public Job { using DeleteJobPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/Job.h b/cpp/src/scheduler/job/Job.h index c646a4f034a85d1a20f77409d54bf91dde00964f..5fe645363fe87f0b5736c3401b48048f560b1e3f 100644 --- a/cpp/src/scheduler/job/Job.h +++ b/cpp/src/scheduler/job/Job.h @@ -16,18 +16,17 @@ // under the License. #pragma once -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -64,6 +63,5 @@ class Job { using JobPtr = std::shared_ptr; using JobWPtr = std::weak_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/SearchJob.cpp b/cpp/src/scheduler/job/SearchJob.cpp index dee7125fed1de745d3762a784d445368f98ce6d0..518e3111c090ad0e7700da47b32ef9a9edb80364 100644 --- a/cpp/src/scheduler/job/SearchJob.cpp +++ b/cpp/src/scheduler/job/SearchJob.cpp @@ -18,24 +18,15 @@ #include "scheduler/job/SearchJob.h" #include "utils/Log.h" -namespace zilliz { namespace milvus { namespace scheduler { -SearchJob::SearchJob(zilliz::milvus::scheduler::JobId id, - uint64_t topk, - uint64_t nq, - uint64_t nprobe, - const float *vectors) - : Job(id, JobType::SEARCH), - topk_(topk), - nq_(nq), - nprobe_(nprobe), - vectors_(vectors) { +SearchJob::SearchJob(milvus::scheduler::JobId id, uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors) + : Job(id, JobType::SEARCH), topk_(topk), nq_(nq), nprobe_(nprobe), vectors_(vectors) { } bool -SearchJob::AddIndexFile(const TableFileSchemaPtr &index_file) { +SearchJob::AddIndexFile(const TableFileSchemaPtr& index_file) { std::unique_lock lock(mutex_); if (index_file == nullptr || index_files_.find(index_file->id_) != index_files_.end()) { return false; @@ -50,9 +41,7 @@ SearchJob::AddIndexFile(const TableFileSchemaPtr &index_file) { void SearchJob::WaitResult() { std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return index_files_.empty(); - }); + cv_.wait(lock, [this] { return index_files_.empty(); }); SERVER_LOG_DEBUG << "SearchJob " << id() << " all done"; } @@ -64,16 +53,15 @@ SearchJob::SearchDone(size_t index_id) { SERVER_LOG_DEBUG << "SearchJob " << id() << " finish index file: " << index_id; } -ResultSet & +ResultSet& SearchJob::GetResult() { return result_; } -Status & +Status& SearchJob::GetStatus() { return status_; } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/job/SearchJob.h b/cpp/src/scheduler/job/SearchJob.h index 7bb7fbefbf4b24232bb826aa4a1909d900da662b..fb2d87d876e131cbe5df952c1b5ad75b6c3f4be8 100644 --- a/cpp/src/scheduler/job/SearchJob.h +++ b/cpp/src/scheduler/job/SearchJob.h @@ -16,38 +16,38 @@ // under the License. #pragma once -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include #include +#include #include "Job.h" #include "db/meta/MetaTypes.h" -namespace zilliz { namespace milvus { namespace scheduler { using engine::meta::TableFileSchemaPtr; using Id2IndexMap = std::unordered_map; -using Id2DistanceMap = std::vector>; -using ResultSet = std::vector; +using IdDistPair = std::pair; +using Id2DistVec = std::vector; +using ResultSet = std::vector; class SearchJob : public Job { public: - SearchJob(JobId id, uint64_t topk, uint64_t nq, uint64_t nprobe, const float *vectors); + SearchJob(JobId id, uint64_t topk, uint64_t nq, uint64_t nprobe, const float* vectors); public: bool - AddIndexFile(const TableFileSchemaPtr &index_file); + AddIndexFile(const TableFileSchemaPtr& index_file); void WaitResult(); @@ -55,10 +55,10 @@ class SearchJob : public Job { void SearchDone(size_t index_id); - ResultSet & + ResultSet& GetResult(); - Status & + Status& GetStatus(); public: @@ -77,12 +77,12 @@ class SearchJob : public Job { return nprobe_; } - const float * + const float* vectors() const { return vectors_; } - Id2IndexMap & + Id2IndexMap& index_files() { return index_files_; } @@ -92,7 +92,7 @@ class SearchJob : public Job { uint64_t nq_ = 0; uint64_t nprobe_ = 0; // TODO: smart pointer - const float *vectors_ = nullptr; + const float* vectors_ = nullptr; Id2IndexMap index_files_; // TODO: column-base better ? @@ -105,6 +105,5 @@ class SearchJob : public Job { using SearchJobPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/optimizer/HybridPass.cpp b/cpp/src/scheduler/optimizer/HybridPass.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f172a7beb96e565b0aa50122636b66112775bf5d --- /dev/null +++ b/cpp/src/scheduler/optimizer/HybridPass.cpp @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "scheduler/optimizer/HybridPass.h" +#include "scheduler/task/SearchTask.h" + +namespace milvus { +namespace scheduler { + +bool +HybridPass::Run(const TaskPtr& task) { + // TODO: Index::IVFSQ8Hybrid, if nq < threshold set cpu, else set gpu + if (task->Type() != TaskType::SearchTask) + return false; + auto search_task = std::static_pointer_cast(task); + // if (search_task->file_->engine_type_ == engine::EngineType::FAISS_IVFSQ8Hybrid) + return false; +} + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/optimizer/HybridPass.h b/cpp/src/scheduler/optimizer/HybridPass.h new file mode 100644 index 0000000000000000000000000000000000000000..0d02a8bda9926211a43d92fe30ccc17996f97c8a --- /dev/null +++ b/cpp/src/scheduler/optimizer/HybridPass.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Pass.h" + +namespace milvus { +namespace scheduler { + +class HybridPass : public Pass { + public: + HybridPass() = default; + + public: + bool + Run(const TaskPtr& task) override; +}; + +using HybridPassPtr = std::shared_ptr; + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/optimizer/Optimizer.cpp b/cpp/src/scheduler/optimizer/Optimizer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5fa311a27279f528e702ed01cee0be995e0703e --- /dev/null +++ b/cpp/src/scheduler/optimizer/Optimizer.cpp @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "scheduler/optimizer/Optimizer.h" + +namespace milvus { +namespace scheduler { + +void +Optimizer::Init() { + for (auto& pass : pass_list_) { + pass->Init(); + } +} + +bool +Optimizer::Run(const TaskPtr& task) { + for (auto& pass : pass_list_) { + if (pass->Run(task)) { + return true; + } + } + + return false; +} + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/optimizer/Optimizer.h b/cpp/src/scheduler/optimizer/Optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..99282e66a679d5480661c1652cf85eaee1151870 --- /dev/null +++ b/cpp/src/scheduler/optimizer/Optimizer.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Pass.h" + +namespace milvus { +namespace scheduler { + +class Optimizer { + public: + Optimizer() = default; + + void + Init(); + + bool + Run(const TaskPtr& task); + + private: + std::vector pass_list_; +}; + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/optimizer/Pass.h b/cpp/src/scheduler/optimizer/Pass.h new file mode 100644 index 0000000000000000000000000000000000000000..959c3ea5ee83d63b2b8a044eeb7a102654826421 --- /dev/null +++ b/cpp/src/scheduler/optimizer/Pass.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scheduler/task/Task.h" + +namespace milvus { +namespace scheduler { + +class Pass { + public: + virtual void + Init() { + } + + virtual bool + Run(const TaskPtr& task) = 0; +}; +using PassPtr = std::shared_ptr; + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/Connection.h b/cpp/src/scheduler/resource/Connection.h index cf18b6c9a214c8947fbe060c7c61c51ba0832650..421d32f768a954909a5bff7aec35cd1f95a4e1fb 100644 --- a/cpp/src/scheduler/resource/Connection.h +++ b/cpp/src/scheduler/resource/Connection.h @@ -17,22 +17,20 @@ #pragma once -#include #include +#include #include -namespace zilliz { namespace milvus { namespace scheduler { class Connection { public: // TODO: update construct function, speed: double->uint64_t - Connection(std::string name, double speed) - : name_(std::move(name)), speed_(speed) { + Connection(std::string name, double speed) : name_(std::move(name)), speed_(speed) { } - const std::string & + const std::string& name() const { return name_; } @@ -60,6 +58,5 @@ class Connection { uint64_t speed_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/CpuResource.cpp b/cpp/src/scheduler/resource/CpuResource.cpp index 5859dfd0cddcc51f985e1467ea1e973490d19b1b..500737a829f2e13b9a7eff1f4505e706943a4174 100644 --- a/cpp/src/scheduler/resource/CpuResource.cpp +++ b/cpp/src/scheduler/resource/CpuResource.cpp @@ -15,17 +15,15 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/resource/CpuResource.h" #include -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const CpuResource &resource) { +std::ostream& +operator<<(std::ostream& out, const CpuResource& resource) { out << resource.Dump(); return out; } @@ -44,6 +42,5 @@ CpuResource::Process(TaskPtr task) { task->Execute(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/CpuResource.h b/cpp/src/scheduler/resource/CpuResource.h index 2226523fdfe55b8320ef950001e08f4d5a737257..e3e4fc383f8a74f9f8b5599c40ba737ac7c3b532 100644 --- a/cpp/src/scheduler/resource/CpuResource.h +++ b/cpp/src/scheduler/resource/CpuResource.h @@ -21,21 +21,20 @@ #include "Resource.h" -namespace zilliz { namespace milvus { namespace scheduler { class CpuResource : public Resource { public: - explicit - CpuResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); + explicit CpuResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); inline std::string Dump() const override { return ""; } - friend std::ostream &operator<<(std::ostream &out, const CpuResource &resource); + friend std::ostream& + operator<<(std::ostream& out, const CpuResource& resource); protected: void @@ -45,6 +44,5 @@ class CpuResource : public Resource { Process(TaskPtr task) override; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/DiskResource.cpp b/cpp/src/scheduler/resource/DiskResource.cpp index eee2424cc11ed16a94f04764dc7c08ce5800d4d1..fe1fc9c8d978b650e05ffd61194f5474875dd2d1 100644 --- a/cpp/src/scheduler/resource/DiskResource.cpp +++ b/cpp/src/scheduler/resource/DiskResource.cpp @@ -20,12 +20,11 @@ #include #include -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const DiskResource &resource) { +std::ostream& +operator<<(std::ostream& out, const DiskResource& resource) { out << resource.Dump(); return out; } @@ -42,6 +41,5 @@ void DiskResource::Process(TaskPtr task) { } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/DiskResource.h b/cpp/src/scheduler/resource/DiskResource.h index a7caf5c6622e97e0e8ed96a42b6907a1203439db..2346cd115a849ef849d0a249ba53e217d4753374 100644 --- a/cpp/src/scheduler/resource/DiskResource.h +++ b/cpp/src/scheduler/resource/DiskResource.h @@ -21,21 +21,20 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { class DiskResource : public Resource { public: - explicit - DiskResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); + explicit DiskResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); inline std::string Dump() const override { return ""; } - friend std::ostream &operator<<(std::ostream &out, const DiskResource &resource); + friend std::ostream& + operator<<(std::ostream& out, const DiskResource& resource); protected: void @@ -45,6 +44,5 @@ class DiskResource : public Resource { Process(TaskPtr task) override; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/GpuResource.cpp b/cpp/src/scheduler/resource/GpuResource.cpp index 3c7abc0b2937f14422a5f426c53b8c0671541ebb..20ed73e38c8d2580c042fe6e06bea95672091a5b 100644 --- a/cpp/src/scheduler/resource/GpuResource.cpp +++ b/cpp/src/scheduler/resource/GpuResource.cpp @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/resource/GpuResource.h" -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const GpuResource &resource) { +std::ostream& +operator<<(std::ostream& out, const GpuResource& resource) { out << resource.Dump(); return out; } @@ -42,6 +40,5 @@ GpuResource::Process(TaskPtr task) { task->Execute(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/GpuResource.h b/cpp/src/scheduler/resource/GpuResource.h index 9f19b07464463f6c6e49c53f9d203631f3a6b674..e0df03d5a78b6ad9915d39062b9fb3762b84f852 100644 --- a/cpp/src/scheduler/resource/GpuResource.h +++ b/cpp/src/scheduler/resource/GpuResource.h @@ -22,21 +22,20 @@ #include #include -namespace zilliz { namespace milvus { namespace scheduler { class GpuResource : public Resource { public: - explicit - GpuResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); + explicit GpuResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); inline std::string Dump() const override { return ""; } - friend std::ostream &operator<<(std::ostream &out, const GpuResource &resource); + friend std::ostream& + operator<<(std::ostream& out, const GpuResource& resource); protected: void @@ -46,6 +45,5 @@ class GpuResource : public Resource { Process(TaskPtr task) override; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/Node.cpp b/cpp/src/scheduler/resource/Node.cpp index cf652b8ba0ac21d73bee593bdea39a45ca605c52..5401c364418d8aa08182010d6b4a6b66095f93d1 100644 --- a/cpp/src/scheduler/resource/Node.cpp +++ b/cpp/src/scheduler/resource/Node.cpp @@ -20,7 +20,6 @@ #include #include -namespace zilliz { namespace milvus { namespace scheduler { @@ -33,7 +32,7 @@ std::vector Node::GetNeighbours() { std::lock_guard lk(mutex_); std::vector ret; - for (auto &e : neighbours_) { + for (auto& e : neighbours_) { ret.push_back(e.second); } return ret; @@ -43,7 +42,7 @@ std::string Node::Dump() { std::stringstream ss; ss << "::neighbours:" << std::endl; - for (auto &neighbour : neighbours_) { + for (auto& neighbour : neighbours_) { ss << "\t" << std::endl; } @@ -51,7 +50,7 @@ Node::Dump() { } void -Node::AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection) { +Node::AddNeighbour(const NeighbourNodePtr& neighbour_node, Connection& connection) { std::lock_guard lk(mutex_); if (auto s = neighbour_node.lock()) { neighbours_.emplace(std::make_pair(s->id_, Neighbour(neighbour_node, connection))); @@ -59,6 +58,5 @@ Node::AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connectio // else do nothing, consider it.. } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/Node.h b/cpp/src/scheduler/resource/Node.h index 00337a70094ae3cc2b399e2e4c47d83662c498bb..071ee9bab82121954af26040bf2cc9d0ece730af 100644 --- a/cpp/src/scheduler/resource/Node.h +++ b/cpp/src/scheduler/resource/Node.h @@ -17,15 +17,14 @@ #pragma once -#include -#include #include +#include #include +#include -#include "scheduler/TaskTable.h" #include "Connection.h" +#include "scheduler/TaskTable.h" -namespace zilliz { namespace milvus { namespace scheduler { @@ -34,21 +33,20 @@ class Node; using NeighbourNodePtr = std::weak_ptr; struct Neighbour { - Neighbour(NeighbourNodePtr nei, Connection conn) - : neighbour_node(nei), connection(conn) { + Neighbour(NeighbourNodePtr nei, Connection conn) : neighbour_node(nei), connection(conn) { } NeighbourNodePtr neighbour_node; Connection connection; }; -// TODO(linxj): return type void -> Status +// TODO(lxj): return type void -> Status class Node { public: Node(); void - AddNeighbour(const NeighbourNodePtr &neighbour_node, Connection &connection); + AddNeighbour(const NeighbourNodePtr& neighbour_node, Connection& connection); std::vector GetNeighbours(); @@ -66,6 +64,5 @@ class Node { using NodePtr = std::shared_ptr; using NodeWPtr = std::weak_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/Resource.cpp b/cpp/src/scheduler/resource/Resource.cpp index 4eb71b8ac52a97b967acbb493c86afa4137be487..8fea475d70b834f92aa66e49e33f1f8848ed7940 100644 --- a/cpp/src/scheduler/resource/Resource.cpp +++ b/cpp/src/scheduler/resource/Resource.cpp @@ -21,21 +21,16 @@ #include #include -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const Resource &resource) { +std::ostream& +operator<<(std::ostream& out, const Resource& resource) { out << resource.Dump(); return out; } -Resource::Resource(std::string name, - ResourceType type, - uint64_t device_id, - bool enable_loader, - bool enable_executor) +Resource::Resource(std::string name, ResourceType type, uint64_t device_id, bool enable_loader, bool enable_executor) : name_(std::move(name)), type_(type), device_id_(device_id), @@ -95,8 +90,9 @@ Resource::WakeupExecutor() { uint64_t Resource::NumOfTaskToExec() { uint64_t count = 0; - for (auto &task : task_table_) { - if (task->state == TaskTableItemState::LOADED) ++count; + for (auto& task : task_table_) { + if (task->state == TaskTableItemState::LOADED) + ++count; } return count; } @@ -129,9 +125,7 @@ void Resource::loader_function() { while (running_) { std::unique_lock lock(load_mutex_); - load_cv_.wait(lock, [&] { - return load_flag_; - }); + load_cv_.wait(lock, [&] { return load_flag_; }); load_flag_ = false; lock.unlock(); while (true) { @@ -157,9 +151,7 @@ Resource::executor_function() { } while (running_) { std::unique_lock lock(exec_mutex_); - exec_cv_.wait(lock, [&] { - return exec_flag_; - }); + exec_cv_.wait(lock, [&] { return exec_flag_; }); exec_flag_ = false; lock.unlock(); while (true) { @@ -183,6 +175,5 @@ Resource::executor_function() { } } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/Resource.h b/cpp/src/scheduler/resource/Resource.h index 1c18b1a2b2497502d41ccec24acde446f22cc60f..c9026f13b6b0279e7add22110a1b40189f55845e 100644 --- a/cpp/src/scheduler/resource/Resource.h +++ b/cpp/src/scheduler/resource/Resource.h @@ -17,25 +17,24 @@ #pragma once -#include -#include +#include +#include #include -#include +#include #include -#include -#include +#include +#include +#include "../TaskTable.h" #include "../event/Event.h" -#include "../event/StartUpEvent.h" -#include "../event/LoadCompletedEvent.h" #include "../event/FinishTaskEvent.h" +#include "../event/LoadCompletedEvent.h" +#include "../event/StartUpEvent.h" #include "../event/TaskTableUpdatedEvent.h" -#include "../TaskTable.h" #include "../task/Task.h" #include "Connection.h" #include "Node.h" -namespace zilliz { namespace milvus { namespace scheduler { @@ -99,7 +98,7 @@ class Resource : public Node, public std::enable_shared_from_this { return device_id_; } - TaskTable & + TaskTable& task_table() { return task_table_; } @@ -115,11 +114,11 @@ class Resource : public Node, public std::enable_shared_from_this { return enable_executor_; } - // TODO: const + // TODO(wxyu): const uint64_t NumOfTaskToExec(); - // TODO: need double ? + // TODO(wxyu): need double ? inline uint64_t TaskAvgCost() const { return total_cost_ / total_task_; @@ -130,14 +129,11 @@ class Resource : public Node, public std::enable_shared_from_this { return total_task_; } - friend std::ostream &operator<<(std::ostream &out, const Resource &resource); + friend std::ostream& + operator<<(std::ostream& out, const Resource& resource); protected: - Resource(std::string name, - ResourceType type, - uint64_t device_id, - bool enable_loader, - bool enable_executor); + Resource(std::string name, ResourceType type, uint64_t device_id, bool enable_loader, bool enable_executor); /* * Implementation by inherit class; @@ -212,6 +208,5 @@ class Resource : public Node, public std::enable_shared_from_this { using ResourcePtr = std::shared_ptr; using ResourceWPtr = std::weak_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/TestResource.cpp b/cpp/src/scheduler/resource/TestResource.cpp index 25560cf7ee65a05b2c9a577a34ca0f07d93b7cae..c8c2fb7537180518b851f34cc3f70c0fac5ef64d 100644 --- a/cpp/src/scheduler/resource/TestResource.cpp +++ b/cpp/src/scheduler/resource/TestResource.cpp @@ -19,12 +19,11 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { -std::ostream & -operator<<(std::ostream &out, const TestResource &resource) { +std::ostream& +operator<<(std::ostream& out, const TestResource& resource) { out << resource.Dump(); return out; } @@ -43,6 +42,5 @@ TestResource::Process(TaskPtr task) { task->Execute(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/resource/TestResource.h b/cpp/src/scheduler/resource/TestResource.h index ac83a42c60c776fd364611e136bf0afb21fa0b87..9bbc5a54d0199d6a1332bc6bc36d83e536eac5b1 100644 --- a/cpp/src/scheduler/resource/TestResource.h +++ b/cpp/src/scheduler/resource/TestResource.h @@ -19,24 +19,23 @@ #include "Resource.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { class TestResource : public Resource { public: - explicit - TestResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); + explicit TestResource(std::string name, uint64_t device_id, bool enable_loader, bool enable_executor); inline std::string Dump() const override { return ""; } - friend std::ostream &operator<<(std::ostream &out, const TestResource &resource); + friend std::ostream& + operator<<(std::ostream& out, const TestResource& resource); protected: void @@ -46,6 +45,5 @@ class TestResource : public Resource { Process(TaskPtr task) override; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/BuildIndexTask.cpp b/cpp/src/scheduler/task/BuildIndexTask.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2cebcac9e24b9b36d6003ad19ca20242cc0715c --- /dev/null +++ b/cpp/src/scheduler/task/BuildIndexTask.cpp @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "scheduler/task/BuildIndexTask.h" +#include "db/engine/EngineFactory.h" +#include "metrics/Metrics.h" +#include "scheduler/job/BuildIndexJob.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +#include +#include +#include +#include + +namespace milvus { +namespace scheduler { + +XBuildIndexTask::XBuildIndexTask(TableFileSchemaPtr file, TaskLabelPtr label) + : Task(TaskType::BuildIndexTask, std::move(label)), file_(file) { + if (file_) { + to_index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, (EngineType)file_->engine_type_, + (MetricType)file_->metric_type_, file_->nlist_); + } +} + +void +XBuildIndexTask::Load(milvus::scheduler::LoadType type, uint8_t device_id) { + TimeRecorder rc(""); + Status stat = Status::OK(); + std::string error_msg; + std::string type_str; + + if (auto job = job_.lock()) { + auto build_index_job = std::static_pointer_cast(job); + auto options = build_index_job->options(); + try { + if (type == LoadType::DISK2CPU) { + stat = to_index_engine_->Load(options.insert_cache_immediately_); + type_str = "DISK2CPU"; + } else if (type == LoadType::CPU2GPU) { + stat = to_index_engine_->CopyToIndexFileToGpu(device_id); + type_str = "CPU2GPU"; + } else if (type == LoadType::GPU2CPU) { + stat = to_index_engine_->CopyToCpu(); + type_str = "GPU2CPU"; + } else { + error_msg = "Wrong load type"; + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied + error_msg = "Failed to load to_index file: " + std::string(ex.what()); + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + + if (!stat.ok()) { + Status s; + if (stat.ToString().find("out of memory") != std::string::npos) { + error_msg = "out of memory: " + type_str; + s = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } else { + error_msg = "Failed to load to_index file: " + type_str; + s = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + + if (auto job = job_.lock()) { + auto build_index_job = std::static_pointer_cast(job); + build_index_job->BuildIndexDone(file_->id_); + } + + return; + } + + size_t file_size = to_index_engine_->PhysicalSize(); + + std::string info = "Load file id:" + std::to_string(file_->id_) + + " file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) + + " bytes from location: " + file_->location_ + " totally cost"; + double span = rc.ElapseFromBegin(info); + + to_index_id_ = file_->id_; + to_index_type_ = file_->file_type_; + } +} + +void +XBuildIndexTask::Execute() { + if (to_index_engine_ == nullptr) { + return; + } + + TimeRecorder rc("DoBuildIndex file id:" + std::to_string(to_index_id_)); + + if (auto job = job_.lock()) { + auto build_index_job = std::static_pointer_cast(job); + std::string location = file_->location_; + EngineType engine_type = (EngineType)file_->engine_type_; + std::shared_ptr index; + + // step 2: create table file + engine::meta::TableFileSchema table_file; + table_file.table_id_ = file_->table_id_; + table_file.date_ = file_->date_; + table_file.file_type_ = engine::meta::TableFileSchema::NEW_INDEX; + + engine::meta::MetaPtr meta_ptr = build_index_job->meta(); + Status status = build_index_job->meta()->CreateTableFile(table_file); + if (!status.ok()) { + ENGINE_LOG_ERROR << "Failed to create table file: " << status.ToString(); + build_index_job->BuildIndexDone(to_index_id_); + build_index_job->GetStatus() = status; + return; + } + + // step 3: build index + try { + index = to_index_engine_->BuildIndex(table_file.location_, (EngineType)table_file.engine_type_); + if (index == nullptr) { + table_file.file_type_ = engine::meta::TableFileSchema::TO_DELETE; + status = meta_ptr->UpdateTableFile(table_file); + ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ + << " to to_delete"; + + return; + } + } catch (std::exception& ex) { + std::string msg = "BuildIndex encounter exception: " + std::string(ex.what()); + ENGINE_LOG_ERROR << msg; + + table_file.file_type_ = engine::meta::TableFileSchema::TO_DELETE; + status = meta_ptr->UpdateTableFile(table_file); + ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete"; + + std::cout << "ERROR: failed to build index, index file is too large or gpu memory is not enough" + << std::endl; + + build_index_job->GetStatus() = Status(DB_ERROR, msg); + return; + } + + // step 4: if table has been deleted, dont save index file + bool has_table = false; + meta_ptr->HasTable(file_->table_id_, has_table); + if (!has_table) { + meta_ptr->DeleteTableFiles(file_->table_id_); + return; + } + + // step 5: save index file + try { + index->Serialize(); + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied + std::string msg = "Serialize index encounter exception: " + std::string(ex.what()); + ENGINE_LOG_ERROR << msg; + + table_file.file_type_ = engine::meta::TableFileSchema::TO_DELETE; + status = meta_ptr->UpdateTableFile(table_file); + ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete"; + + std::cout << "ERROR: failed to persist index file: " << table_file.location_ + << ", possible out of disk space" << std::endl; + + build_index_job->GetStatus() = Status(DB_ERROR, msg); + return; + } + + // step 6: update meta + table_file.file_type_ = engine::meta::TableFileSchema::INDEX; + table_file.file_size_ = index->PhysicalSize(); + table_file.row_count_ = index->Count(); + + auto origin_file = *file_; + origin_file.file_type_ = engine::meta::TableFileSchema::BACKUP; + + engine::meta::TableFilesSchema update_files = {table_file, origin_file}; + status = meta_ptr->UpdateTableFiles(update_files); + if (status.ok()) { + ENGINE_LOG_DEBUG << "New index file " << table_file.file_id_ << " of size " << index->PhysicalSize() + << " bytes" + << " from file " << origin_file.file_id_; + + // index->Cache(); + } else { + // failed to update meta, mark the new file as to_delete, don't delete old file + origin_file.file_type_ = engine::meta::TableFileSchema::TO_INDEX; + status = meta_ptr->UpdateTableFile(origin_file); + ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << origin_file.file_id_ << " to to_index"; + + table_file.file_type_ = engine::meta::TableFileSchema::TO_DELETE; + status = meta_ptr->UpdateTableFile(table_file); + ENGINE_LOG_DEBUG << "Failed to up date file to index, mark file: " << table_file.file_id_ + << " to to_delete"; + } + + build_index_job->BuildIndexDone(to_index_id_); + } + + rc.ElapseFromBegin("totally cost"); + + to_index_engine_ = nullptr; +} + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/BuildIndexTask.h b/cpp/src/scheduler/task/BuildIndexTask.h new file mode 100644 index 0000000000000000000000000000000000000000..5c2aa69a009b73d30a76b1ed165ff2d9450493ce --- /dev/null +++ b/cpp/src/scheduler/task/BuildIndexTask.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "Task.h" +#include "scheduler/Definition.h" +#include "scheduler/job/BuildIndexJob.h" + +namespace milvus { +namespace scheduler { + +class XBuildIndexTask : public Task { + public: + explicit XBuildIndexTask(TableFileSchemaPtr file, TaskLabelPtr label); + + void + Load(LoadType type, uint8_t device_id) override; + + void + Execute() override; + + public: + TableFileSchemaPtr file_; + TableFileSchema table_file_; + size_t to_index_id_ = 0; + int to_index_type_ = 0; + ExecutionEnginePtr to_index_engine_ = nullptr; +}; + +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/DeleteTask.cpp b/cpp/src/scheduler/task/DeleteTask.cpp index 52579d67c67b5edd2061a84702703f5b8e599e03..bffe78cf8ff528df6a07c1f600710b634362ebf7 100644 --- a/cpp/src/scheduler/task/DeleteTask.cpp +++ b/cpp/src/scheduler/task/DeleteTask.cpp @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. - #include "scheduler/task/DeleteTask.h" -namespace zilliz { +#include + namespace milvus { namespace scheduler { -XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr &delete_job) - : Task(TaskType::DeleteTask), delete_job_(delete_job) { +XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label) + : Task(TaskType::DeleteTask, std::move(label)), delete_job_(delete_job) { } void @@ -35,6 +35,5 @@ XDeleteTask::Execute() { delete_job_->ResourceDone(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/DeleteTask.h b/cpp/src/scheduler/task/DeleteTask.h index 608960e7c868b612dfc121a061e32bc9c28f34e0..fd5222ba4e19a592dcb4e4f3e95bfe0e7441d86a 100644 --- a/cpp/src/scheduler/task/DeleteTask.h +++ b/cpp/src/scheduler/task/DeleteTask.h @@ -17,16 +17,15 @@ #pragma once -#include "scheduler/job/DeleteJob.h" #include "Task.h" +#include "scheduler/job/DeleteJob.h" -namespace zilliz { namespace milvus { namespace scheduler { class XDeleteTask : public Task { public: - explicit XDeleteTask(const scheduler::DeleteJobPtr &job); + explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label); void Load(LoadType type, uint8_t device_id) override; @@ -38,6 +37,5 @@ class XDeleteTask : public Task { scheduler::DeleteJobPtr delete_job_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/Path.h b/cpp/src/scheduler/task/Path.h index 672dfff1b9b36c251f4f23ff4a6d8972050dbe4b..c23db9bb09967382042efa7208cae772397f6d79 100644 --- a/cpp/src/scheduler/task/Path.h +++ b/cpp/src/scheduler/task/Path.h @@ -17,10 +17,9 @@ #pragma once -#include #include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -28,11 +27,11 @@ class Path { public: Path() = default; - Path(std::vector &path, uint64_t index) : path_(path), index_(index) { + Path(std::vector& path, uint64_t index) : path_(path), index_(index) { } void - push_back(const std::string &str) { + push_back(const std::string& str) { path_.push_back(str); } @@ -61,16 +60,17 @@ class Path { } public: - std::string & - operator[](uint64_t index) { + std::string& operator[](uint64_t index) { return path_[index]; } - std::vector::iterator begin() { + std::vector::iterator + begin() { return path_.begin(); } - std::vector::iterator end() { + std::vector::iterator + end() { return path_.end(); } @@ -79,6 +79,5 @@ class Path { uint64_t index_ = 0; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index 0c205fcafabcb2cfcd201a59f37fc72221c6e226..9925a8bcf85511d83759ea2fd88aacca1ba7f44a 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -16,17 +16,17 @@ // under the License. #include "scheduler/task/SearchTask.h" -#include "scheduler/job/SearchJob.h" -#include "metrics/Metrics.h" #include "db/engine/EngineFactory.h" -#include "utils/TimeRecorder.h" +#include "metrics/Metrics.h" +#include "scheduler/job/SearchJob.h" #include "utils/Log.h" +#include "utils/TimeRecorder.h" +#include +#include #include #include -#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -35,8 +35,9 @@ static constexpr size_t PARALLEL_REDUCE_BATCH = 1000; std::mutex XSearchTask::merge_mutex_; -//bool -//NeedParallelReduce(uint64_t nq, uint64_t topk) { +// TODO(wxyu): remove unused code +// bool +// NeedParallelReduce(uint64_t nq, uint64_t topk) { // server::ServerConfig &config = server::ServerConfig::GetInstance(); // server::ConfigNode &db_config = config.GetConfig(server::CONFIG_DB); // bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, false); @@ -47,8 +48,8 @@ std::mutex XSearchTask::merge_mutex_; // return nq * topk >= PARALLEL_REDUCE_THRESHOLD; //} // -//void -//ParallelReduce(std::function &reduce_function, size_t max_index) { +// void +// ParallelReduce(std::function &reduce_function, size_t max_index) { // size_t reduce_batch = PARALLEL_REDUCE_BATCH; // // auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work @@ -79,31 +80,32 @@ std::mutex XSearchTask::merge_mutex_; void CollectFileMetrics(int file_type, size_t file_size) { + server::MetricsBase& inst = server::Metrics::GetInstance(); switch (file_type) { case TableFileSchema::RAW: case TableFileSchema::TO_INDEX: { - server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size); + inst.RawFileSizeHistogramObserve(file_size); + inst.RawFileSizeTotalIncrement(file_size); + inst.RawFileSizeGaugeSet(file_size); break; } default: { - server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size); + inst.IndexFileSizeHistogramObserve(file_size); + inst.IndexFileSizeTotalIncrement(file_size); + inst.IndexFileSizeGaugeSet(file_size); break; } } } -XSearchTask::XSearchTask(TableFileSchemaPtr file) - : Task(TaskType::SearchTask), file_(file) { +XSearchTask::XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label) + : Task(TaskType::SearchTask, std::move(label)), file_(file) { if (file_) { - index_engine_ = EngineFactory::Build(file_->dimension_, - file_->location_, - (EngineType) file_->engine_type_, - (MetricType) file_->metric_type_, - file_->nlist_); + if (file_->metric_type_ != static_cast(MetricType::L2)) { + metric_l2 = false; + } + index_engine_ = EngineFactory::Build(file_->dimension_, file_->location_, (EngineType)file_->engine_type_, + (MetricType)file_->metric_type_, file_->nlist_); } } @@ -128,8 +130,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) { error_msg = "Wrong load type"; stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); } - } catch (std::exception &ex) { - //typical error: out of disk space or permition denied + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied error_msg = "Failed to load index file: " + std::string(ex.what()); stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); } @@ -155,19 +157,20 @@ XSearchTask::Load(LoadType type, uint8_t device_id) { size_t file_size = index_engine_->PhysicalSize(); - std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + std::to_string(file_->file_type_) - + " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost"; + std::string info = "Load file id:" + std::to_string(file_->id_) + + " file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) + + " bytes from location: " + file_->location_ + " totally cost"; double span = rc.ElapseFromBegin(info); -// for (auto &context : search_contexts_) { -// context->AccumLoadCost(span); -// } + // for (auto &context : search_contexts_) { + // context->AccumLoadCost(span); + // } CollectFileMetrics(file_->file_type_, file_size); - //step 2: return search task for later execution + // step 2: return search task for later execution index_id_ = file_->id_; index_type_ = file_->file_type_; -// search_contexts_.swap(search_contexts_); + // search_contexts_.swap(search_contexts_); } void @@ -176,8 +179,8 @@ XSearchTask::Execute() { return; } -// ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_ << " with " -// << search_contexts_.size() << " tasks"; + // ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_ << " with " + // << search_contexts_.size() << " tasks"; TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_)); @@ -188,45 +191,36 @@ XSearchTask::Execute() { if (auto job = job_.lock()) { auto search_job = std::static_pointer_cast(job); - //step 1: allocate memory + // step 1: allocate memory uint64_t nq = search_job->nq(); uint64_t topk = search_job->topk(); uint64_t nprobe = search_job->nprobe(); - const float *vectors = search_job->vectors(); + const float* vectors = search_job->vectors(); output_ids.resize(topk * nq); output_distance.resize(topk * nq); - std::string hdr = "job " + std::to_string(search_job->id()) + - " nq " + std::to_string(nq) + - " topk " + std::to_string(topk); + std::string hdr = + "job " + std::to_string(search_job->id()) + " nq " + std::to_string(nq) + " topk " + std::to_string(topk); try { - //step 2: search + // step 2: search index_engine_->Search(nq, vectors, topk, nprobe, output_distance.data(), output_ids.data()); double span = rc.RecordSection(hdr + ", do search"); -// search_job->AccumSearchCost(span); + // search_job->AccumSearchCost(span); - - //step 3: cluster result - scheduler::ResultSet result_set; + // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; - XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set); - - span = rc.RecordSection(hdr + ", cluster result"); -// search_job->AccumReduceCost(span); - - // step 4: pick up topk result - XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult()); + XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); -// search_job->AccumReduceCost(span); - } catch (std::exception &ex) { + // search_job->AccumReduceCost(span); + } catch (std::exception& ex) { ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what(); -// search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed + // search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed } - //step 5: notify to send result to client + // step 5: notify to send result to client search_job->SearchDone(index_id_); } @@ -237,153 +231,73 @@ XSearchTask::Execute() { } Status -XSearchTask::ClusterResult(const std::vector &output_ids, - const std::vector &output_distance, - uint64_t nq, - uint64_t topk, - scheduler::ResultSet &result_set) { - if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) { - std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + - " distance array size: " + std::to_string(output_distance.size()); - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } - - result_set.clear(); - result_set.resize(nq); - - std::function reduce_worker = [&](size_t from_index, size_t to_index) { - for (auto i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap id_distance; - id_distance.reserve(topk); - for (auto k = 0; k < topk; k++) { - uint64_t index = i * topk + k; - if (output_ids[index] < 0) { - continue; - } - id_distance.push_back(std::make_pair(output_ids[index], output_distance[index])); - } - result_set[i] = id_distance; - } - }; - -// if (NeedParallelReduce(nq, topk)) { -// ParallelReduce(reduce_worker, nq); -// } else { - reduce_worker(0, nq); -// } - - return Status::OK(); -} - -Status -XSearchTask::MergeResult(scheduler::Id2DistanceMap &distance_src, - scheduler::Id2DistanceMap &distance_target, - uint64_t topk, - bool ascending) { - //Note: the score_src and score_target are already arranged by score in ascending order - if (distance_src.empty()) { - ENGINE_LOG_WARNING << "Empty distance source array"; - return Status::OK(); - } - - std::unique_lock lock(merge_mutex_); - if (distance_target.empty()) { - distance_target.swap(distance_src); - return Status::OK(); - } - - size_t src_count = distance_src.size(); - size_t target_count = distance_target.size(); - scheduler::Id2DistanceMap distance_merged; - distance_merged.reserve(topk); - size_t src_index = 0, target_index = 0; - while (true) { - //all score_src items are merged, if score_merged.size() still less than topk - //move items from score_target to score_merged until score_merged.size() equal topk - if (src_index >= src_count) { - for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_target[i]); +XSearchTask::TopkResult(const std::vector& input_ids, const std::vector& input_distance, + uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { + scheduler::ResultSet result_buf; + + if (result.empty()) { + result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + for (auto k = 0; k < input_k; ++k) { + uint64_t idx = input_k_multi_i + k; + auto& result_buf_item = result_buf_i[k]; + result_buf_item.first = input_ids[idx]; + result_buf_item.second = input_distance[idx]; } - break; } - - //all score_target items are merged, if score_merged.size() still less than topk - //move items from score_src to score_merged until score_merged.size() equal topk - if (target_index >= target_count) { - for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_src[i]); + } else { + size_t tar_size = result[0].size(); + uint64_t output_k = std::min(topk, input_k + tar_size); + result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + size_t buf_k = 0, src_k = 0, tar_k = 0; + uint64_t src_idx; + auto& result_i = result[i]; + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + auto& result_item = result_i[tar_k]; + if ((ascending && input_distance[src_idx] < result_item.second) || + (!ascending && input_distance[src_idx] > result_item.second)) { + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + } else { + result_buf_item = result_item; + tar_k++; + } + buf_k++; } - break; - } - //compare score, - // if ascending = true, put smallest score to score_merged one by one - // else, put largest score to score_merged one by one - auto &src_pair = distance_src[src_index]; - auto &target_pair = distance_target[target_index]; - if (ascending) { - if (src_pair.second > target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; - } - } else { - if (src_pair.second < target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; + if (buf_k < topk) { + if (src_k < input_k) { + while (buf_k < output_k && src_k < input_k) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + buf_k++; + } + } else { + while (buf_k < output_k && tar_k < tar_size) { + result_buf_i[buf_k] = result_i[tar_k]; + tar_k++; + buf_k++; + } + } } } - - //score_merged.size() already equal topk - if (distance_merged.size() >= topk) { - break; - } } - distance_target.swap(distance_merged); - - return Status::OK(); -} - -Status -XSearchTask::TopkResult(scheduler::ResultSet &result_src, - uint64_t topk, - bool ascending, - scheduler::ResultSet &result_target) { - if (result_target.empty()) { - result_target.swap(result_src); - return Status::OK(); - } - - if (result_src.size() != result_target.size()) { - std::string msg = "Invalid result set size"; - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } - - std::function ReduceWorker = [&](size_t from_index, size_t to_index) { - for (size_t i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap &score_src = result_src[i]; - scheduler::Id2DistanceMap &score_target = result_target[i]; - XSearchTask::MergeResult(score_src, score_target, topk, ascending); - } - }; - -// if (NeedParallelReduce(result_src.size(), topk)) { -// ParallelReduce(ReduceWorker, result_src.size()); -// } else { - ReduceWorker(0, result_src.size()); -// } + result.swap(result_buf); return Status::OK(); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index 7c19ba20f961966455e6aa798aa9bf356b735f34..fd5c8a0d1d5dacd838429330f9e8d0ab44cb77de 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -18,19 +18,18 @@ #pragma once #include "Task.h" -#include "scheduler/job/SearchJob.h" #include "scheduler/Definition.h" +#include "scheduler/job/SearchJob.h" #include -namespace zilliz { namespace milvus { namespace scheduler { -// TODO: rewrite +// TODO(wxyu): rewrite class XSearchTask : public Task { public: - explicit XSearchTask(TableFileSchemaPtr file); + explicit XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label); void Load(LoadType type, uint8_t device_id) override; @@ -39,21 +38,9 @@ class XSearchTask : public Task { Execute() override; public: - static Status ClusterResult(const std::vector &output_ids, - const std::vector &output_distence, - uint64_t nq, - uint64_t topk, - scheduler::ResultSet &result_set); - - static Status MergeResult(scheduler::Id2DistanceMap &distance_src, - scheduler::Id2DistanceMap &distance_target, - uint64_t topk, - bool ascending); - - static Status TopkResult(scheduler::ResultSet &result_src, - uint64_t topk, - bool ascending, - scheduler::ResultSet &result_target); + static Status + TopkResult(const std::vector& input_ids, const std::vector& input_distance, uint64_t input_k, + uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); public: TableFileSchemaPtr file_; @@ -66,6 +53,5 @@ class XSearchTask : public Task { static std::mutex merge_mutex_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/Task.h b/cpp/src/scheduler/task/Task.h index 3600c10f033b5de5fe0d7cf83fa0cea02e91037b..411c18cbdeb8e14e55f51c51e7d94f1347ee0105 100644 --- a/cpp/src/scheduler/task/Task.h +++ b/cpp/src/scheduler/task/Task.h @@ -17,15 +17,15 @@ #pragma once -#include "scheduler/tasklabel/TaskLabel.h" +#include "Path.h" #include "scheduler/job/Job.h" +#include "scheduler/tasklabel/TaskLabel.h" #include "utils/Status.h" -#include "Path.h" -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace scheduler { @@ -39,6 +39,7 @@ enum class LoadType { enum class TaskType { SearchTask, DeleteTask, + BuildIndexTask, TestTask, }; @@ -49,7 +50,7 @@ using TaskPtr = std::shared_ptr; // TODO: re-design class Task { public: - explicit Task(TaskType type) : type_(type) { + explicit Task(TaskType type, TaskLabelPtr label) : type_(type), label_(std::move(label)) { } /* @@ -63,7 +64,7 @@ class Task { /* * Transport path; */ - inline Path & + inline Path& path() { return task_path_; } @@ -71,7 +72,7 @@ class Task { /* * Getter and Setter; */ - inline TaskLabelPtr & + inline TaskLabelPtr& label() { return label_; } @@ -85,12 +86,10 @@ class Task { public: Path task_path_; -// std::vector search_contexts_; scheduler::JobWPtr job_; TaskType type_; TaskLabelPtr label_ = nullptr; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/TestTask.cpp b/cpp/src/scheduler/task/TestTask.cpp index fc66e562691d85f418d7a61bc5fb2dce5ef87d30..3ec3a8ab196fc2e4d09f362b8878e349367db54c 100644 --- a/cpp/src/scheduler/task/TestTask.cpp +++ b/cpp/src/scheduler/task/TestTask.cpp @@ -18,11 +18,12 @@ #include "scheduler/task/TestTask.h" #include "cache/GpuCacheMgr.h" -namespace zilliz { +#include + namespace milvus { namespace scheduler { -TestTask::TestTask(TableFileSchemaPtr &file) : XSearchTask(file) { +TestTask::TestTask(TableFileSchemaPtr& file, TaskLabelPtr label) : XSearchTask(file, std::move(label)) { } void @@ -43,11 +44,8 @@ TestTask::Execute() { void TestTask::Wait() { std::unique_lock lock(mutex_); - cv_.wait(lock, [&] { - return done_; - }); + cv_.wait(lock, [&] { return done_; }); } -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/task/TestTask.h b/cpp/src/scheduler/task/TestTask.h index 7051080861695e85cc4697d61b02d2cdbf000a9d..99b48a8afe4534a350903db062d15c987f0abfe1 100644 --- a/cpp/src/scheduler/task/TestTask.h +++ b/cpp/src/scheduler/task/TestTask.h @@ -19,13 +19,12 @@ #include "SearchTask.h" -namespace zilliz { namespace milvus { namespace scheduler { class TestTask : public XSearchTask { public: - explicit TestTask(TableFileSchemaPtr &file); + explicit TestTask(TableFileSchemaPtr& file, TaskLabelPtr label); public: void @@ -46,6 +45,5 @@ class TestTask : public XSearchTask { std::condition_variable cv_; }; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/tasklabel/BroadcastLabel.h b/cpp/src/scheduler/tasklabel/BroadcastLabel.h index 6fca107864f59aef82abe1678303e033efb74417..f0b48afb238772093dd3fa189f9b3288c6d6bf23 100644 --- a/cpp/src/scheduler/tasklabel/BroadcastLabel.h +++ b/cpp/src/scheduler/tasklabel/BroadcastLabel.h @@ -21,7 +21,6 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { @@ -33,7 +32,5 @@ class BroadcastLabel : public TaskLabel { using BroadcastLabelPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz - +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/tasklabel/DefaultLabel.h b/cpp/src/scheduler/tasklabel/DefaultLabel.h index 7943c4f7c19355e2247a8d26eec80ffcdc45d334..c2157435757cbb41392a498e9dc82964f7506206 100644 --- a/cpp/src/scheduler/tasklabel/DefaultLabel.h +++ b/cpp/src/scheduler/tasklabel/DefaultLabel.h @@ -21,7 +21,6 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { @@ -33,6 +32,5 @@ class DefaultLabel : public TaskLabel { using DefaultLabelPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/tasklabel/SpecResLabel.h b/cpp/src/scheduler/tasklabel/SpecResLabel.h index cfc5aa94000a67e2e90de061ef1d008181fb528d..db2989fbc244c92c128bcade16e5a62738563cc7 100644 --- a/cpp/src/scheduler/tasklabel/SpecResLabel.h +++ b/cpp/src/scheduler/tasklabel/SpecResLabel.h @@ -18,30 +18,30 @@ #pragma once #include "TaskLabel.h" +#include "scheduler/ResourceMgr.h" -#include #include +#include -class Resource; - -using ResourceWPtr = std::weak_ptr; +// class Resource; +// +// using ResourceWPtr = std::weak_ptr; -namespace zilliz { namespace milvus { namespace scheduler { class SpecResLabel : public TaskLabel { public: - explicit SpecResLabel(const ResourceWPtr &resource) + explicit SpecResLabel(const ResourceWPtr& resource) : TaskLabel(TaskLabelType::SPECIFIED_RESOURCE), resource_(resource) { } - inline ResourceWPtr & + inline ResourceWPtr& resource() { return resource_; } - inline std::string & + inline std::string& resource_name() { return resource_name_; } @@ -53,6 +53,5 @@ class SpecResLabel : public TaskLabel { using SpecResLabelPtr = std::shared_ptr(); -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/scheduler/tasklabel/TaskLabel.h b/cpp/src/scheduler/tasklabel/TaskLabel.h index d8d404e09477822188df280ab581aeb4afa245a8..d35ce409ffbcbd4503418d9f321525c8b7f8dfe3 100644 --- a/cpp/src/scheduler/tasklabel/TaskLabel.h +++ b/cpp/src/scheduler/tasklabel/TaskLabel.h @@ -19,14 +19,13 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { enum class TaskLabelType { - DEFAULT, // means can be executed in any resource - SPECIFIED_RESOURCE, // means must executing in special resource - BROADCAST, // means all enable-executor resource must execute task + DEFAULT, // means can be executed in any resource + SPECIFIED_RESOURCE, // means must executing in special resource + BROADCAST, // means all enable-executor resource must execute task }; class TaskLabel { @@ -46,6 +45,5 @@ class TaskLabel { using TaskLabelPtr = std::shared_ptr; -} // namespace scheduler -} // namespace milvus -} // namespace zilliz +} // namespace scheduler +} // namespace milvus diff --git a/cpp/src/sdk/CMakeLists.txt b/cpp/src/sdk/CMakeLists.txt index e7bf989e96ad75533898c060afc877884f2cb469..a2991a49b4a4a305bf80d8cd5074218689b23bbd 100644 --- a/cpp/src/sdk/CMakeLists.txt +++ b/cpp/src/sdk/CMakeLists.txt @@ -17,11 +17,9 @@ # under the License. #------------------------------------------------------------------------------- - -aux_source_directory(interface interface_files) - include_directories(include) +aux_source_directory(interface interface_files) aux_source_directory(grpc grpc_client_files) add_library(milvus_sdk STATIC diff --git a/cpp/src/sdk/examples/grpcsimple/main.cpp b/cpp/src/sdk/examples/grpcsimple/main.cpp index 166707259a46849b672abb0bba17153f2f405737..c31f491afb2a8363dfbeaef874e7c7a03c1c8c1f 100644 --- a/cpp/src/sdk/examples/grpcsimple/main.cpp +++ b/cpp/src/sdk/examples/grpcsimple/main.cpp @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #include #include #include @@ -24,17 +23,17 @@ #include "src/ClientTest.h" void -print_help(const std::string &app_name); +print_help(const std::string& app_name); int -main(int argc, char *argv[]) { +main(int argc, char* argv[]) { printf("Client start...\n"); std::string app_name = basename(argv[0]); - static struct option long_options[] = {{"server", optional_argument, 0, 's'}, - {"port", optional_argument, 0, 'p'}, - {"help", no_argument, 0, 'h'}, - {NULL, 0, 0, 0}}; + static struct option long_options[] = {{"server", optional_argument, nullptr, 's'}, + {"port", optional_argument, nullptr, 'p'}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, 0, nullptr, 0}}; int option_index = 0; std::string address = "127.0.0.1", port = "19530"; @@ -44,19 +43,20 @@ main(int argc, char *argv[]) { while ((value = getopt_long(argc, argv, "s:p:h", long_options, &option_index)) != -1) { switch (value) { case 's': { - char *address_ptr = strdup(optarg); + char* address_ptr = strdup(optarg); address = address_ptr; free(address_ptr); break; } case 'p': { - char *port_ptr = strdup(optarg); + char* port_ptr = strdup(optarg); port = port_ptr; free(port_ptr); break; } case 'h': - default:print_help(app_name); + default: + print_help(app_name); return EXIT_SUCCESS; } } @@ -69,7 +69,7 @@ main(int argc, char *argv[]) { } void -print_help(const std::string &app_name) { +print_help(const std::string& app_name) { printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str()); printf(" Options:\n"); printf(" -s --server Server address, default 127.0.0.1\n"); diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp index e29bcc30f380b41e2303bacf6bf6ea4d3245e372..ce511714b2882312b6e417fa65b27ec3df61877d 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -17,16 +17,15 @@ #include "sdk/examples/grpcsimple/src/ClientTest.h" #include "MilvusApi.h" -#include "cache/CpuCacheMgr.h" -#include #include -#include -#include #include +#include +#include #include -#include +#include #include +#include //#define SET_VECTOR_IDS; @@ -40,14 +39,14 @@ constexpr int64_t TABLE_INDEX_FILE_SIZE = 1024; constexpr int64_t BATCH_ROW_COUNT = 100000; constexpr int64_t NQ = 5; constexpr int64_t TOP_K = 10; -constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different +constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different constexpr int64_t ADD_VECTOR_LOOP = 1; constexpr int64_t SECONDS_EACH_HOUR = 3600; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl; void -PrintTableSchema(const milvus::TableSchema &tb_schema) { +PrintTableSchema(const milvus::TableSchema& tb_schema) { BLOCK_SPLITER std::cout << "Table name: " << tb_schema.table_name << std::endl; std::cout << "Table dimension: " << tb_schema.dimension << std::endl; @@ -55,19 +54,18 @@ PrintTableSchema(const milvus::TableSchema &tb_schema) { } void -PrintSearchResult(const std::vector> &search_record_array, - const std::vector &topk_query_result_array) { +PrintSearchResult(const std::vector>& search_record_array, + const std::vector& topk_query_result_array) { BLOCK_SPLITER std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; int32_t index = 0; - for (auto &result : topk_query_result_array) { + for (auto& result : topk_query_result_array) { auto search_id = search_record_array[index].first; index++; - std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) - << " top " << std::to_string(result.query_result_arrays.size()) - << " search result:" << std::endl; - for (auto &item : result.query_result_arrays) { + std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) << " top " + << std::to_string(result.query_result_arrays.size()) << " search result:" << std::endl; + for (auto& item : result.query_result_arrays) { std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance); std::cout << std::endl; } @@ -84,9 +82,9 @@ CurrentTime() { tm t; gmtime_r(&tt, &t); - std::string str = std::to_string(t.tm_year + 1900) + "_" + std::to_string(t.tm_mon + 1) - + "_" + std::to_string(t.tm_mday) + "_" + std::to_string(t.tm_hour) - + "_" + std::to_string(t.tm_min) + "_" + std::to_string(t.tm_sec); + std::string str = std::to_string(t.tm_year + 1900) + "_" + std::to_string(t.tm_mon + 1) + "_" + + std::to_string(t.tm_mday) + "_" + std::to_string(t.tm_hour) + "_" + std::to_string(t.tm_min) + + "_" + std::to_string(t.tm_sec); return str; } @@ -100,8 +98,8 @@ CurrentTmDate(int64_t offset_day = 0) { tm t; gmtime_r(&tt, &t); - std::string str = std::to_string(t.tm_year + 1900) + "-" + std::to_string(t.tm_mon + 1) - + "-" + std::to_string(t.tm_mday); + std::string str = + std::to_string(t.tm_year + 1900) + "-" + std::to_string(t.tm_mon + 1) + "-" + std::to_string(t.tm_mday); return str; } @@ -124,8 +122,7 @@ BuildTableSchema() { } void -BuildVectors(int64_t from, int64_t to, - std::vector &vector_record_array) { +BuildVectors(int64_t from, int64_t to, std::vector& vector_record_array) { if (to <= from) { return; } @@ -135,7 +132,7 @@ BuildVectors(int64_t from, int64_t to, milvus::RowRecord record; record.data.resize(TABLE_DIMENSION); for (int64_t i = 0; i < TABLE_DIMENSION; i++) { - record.data[i] = (float) (k % (i + 1)); + record.data[i] = (float)(k % (i + 1)); } vector_record_array.emplace_back(record); @@ -150,8 +147,7 @@ Sleep(int seconds) { class TimeRecorder { public: - explicit TimeRecorder(const std::string &title) - : title_(title) { + explicit TimeRecorder(const std::string& title) : title_(title) { start_ = std::chrono::system_clock::now(); } @@ -167,16 +163,15 @@ class TimeRecorder { }; void -CheckResult(const std::vector> &search_record_array, - const std::vector &topk_query_result_array) { +CheckResult(const std::vector>& search_record_array, + const std::vector& topk_query_result_array) { BLOCK_SPLITER int64_t index = 0; - for (auto &result : topk_query_result_array) { + for (auto& result : topk_query_result_array) { auto result_id = result.query_result_arrays[0].id; auto search_id = search_record_array[index++].first; if (result_id != search_id) { - std::cout << "The top 1 result is wrong: " << result_id - << " vs. " << search_id << std::endl; + std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl; } else { std::cout << "Check result sucessfully" << std::endl; } @@ -186,8 +181,7 @@ CheckResult(const std::vector> &search_rec void DoSearch(std::shared_ptr conn, - const std::vector> &search_record_array, - const std::string &phase_name) { + const std::vector>& search_record_array, const std::string& phase_name) { std::vector query_range_array; milvus::Range rg; rg.start_value = CurrentTmDate(); @@ -195,7 +189,7 @@ DoSearch(std::shared_ptr conn, query_range_array.emplace_back(rg); std::vector record_array; - for (auto &pair : search_record_array) { + for (auto& pair : search_record_array) { record_array.push_back(pair.second); } @@ -214,24 +208,24 @@ DoSearch(std::shared_ptr conn, PrintSearchResult(search_record_array, topk_query_result_array); CheckResult(search_record_array, topk_query_result_array); } -} // namespace +} // namespace void -ClientTest::Test(const std::string &address, const std::string &port) { +ClientTest::Test(const std::string& address, const std::string& port) { std::shared_ptr conn = milvus::Connection::Create(); - {//connect server + { // connect server milvus::ConnectParam param = {address, port}; milvus::Status stat = conn->Connect(param); std::cout << "Connect function call status: " << stat.message() << std::endl; } - {//server version + { // server version std::string version = conn->ServerVersion(); std::cout << "Server version: " << version << std::endl; } - {//sdk version + { // sdk version std::string version = conn->ClientVersion(); std::cout << "SDK version: " << version << std::endl; } @@ -241,15 +235,15 @@ ClientTest::Test(const std::string &address, const std::string &port) { milvus::Status stat = conn->ShowTables(tables); std::cout << "ShowTables function call status: " << stat.message() << std::endl; std::cout << "All tables: " << std::endl; - for (auto &table : tables) { + for (auto& table : tables) { int64_t row_count = 0; -// conn->DropTable(table); + // conn->DropTable(table); stat = conn->CountTable(table, row_count); std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl; } } - {//create table + { // create table milvus::TableSchema tb_schema = BuildTableSchema(); milvus::Status stat = conn->CreateTable(tb_schema); std::cout << "CreateTable function call status: " << stat.message() << std::endl; @@ -261,7 +255,7 @@ ClientTest::Test(const std::string &address, const std::string &port) { } } - {//describe table + { // describe table milvus::TableSchema tb_schema; milvus::Status stat = conn->DescribeTable(TABLE_NAME, tb_schema); std::cout << "DescribeTable function call status: " << stat.message() << std::endl; @@ -269,21 +263,21 @@ ClientTest::Test(const std::string &address, const std::string &port) { } std::vector> search_record_array; - {//insert vectors - for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors + { // insert vectors + for (int i = 0; i < ADD_VECTOR_LOOP; i++) { // add vectors std::vector record_array; int64_t begin_index = i * BATCH_ROW_COUNT; BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array); #ifdef SET_VECTOR_IDS record_ids.resize(ADD_VECTOR_LOOP * BATCH_ROW_COUNT); - for (auto j = begin_index; j record_ids; - //generate user defined ids + // generate user defined ids for (int k = 0; k < BATCH_ROW_COUNT; k++) { record_ids.push_back(i * BATCH_ROW_COUNT + k); } @@ -299,22 +293,21 @@ ClientTest::Test(const std::string &address, const std::string &port) { std::cout << "Returned id array count: " << record_ids.size() << std::endl; if (search_record_array.size() < NQ) { - search_record_array.push_back( - std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET])); + search_record_array.push_back(std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET])); } } } - {//search vectors without index + { // search vectors without index Sleep(2); int64_t row_count = 0; milvus::Status stat = conn->CountTable(TABLE_NAME, row_count); std::cout << TABLE_NAME << "(" << row_count << " rows)" << std::endl; -// DoSearch(conn, search_record_array, "Search without index"); + // DoSearch(conn, search_record_array, "Search without index"); } - {//wait unit build index finish + { // wait unit build index finish std::cout << "Wait until create all index done" << std::endl; milvus::IndexParam index; index.table_name = TABLE_NAME; @@ -328,19 +321,19 @@ ClientTest::Test(const std::string &address, const std::string &port) { std::cout << "DescribeIndex function call status: " << stat.message() << std::endl; } - {//preload table + { // preload table milvus::Status stat = conn->PreloadTable(TABLE_NAME); std::cout << "PreloadTable function call status: " << stat.message() << std::endl; } - {//search vectors after build index finish + { // search vectors after build index finish for (uint64_t i = 0; i < 5; ++i) { DoSearch(conn, search_record_array, "Search after build index finish"); } -// std::cout << conn->DumpTaskTables() << std::endl; + // std::cout << conn->DumpTaskTables() << std::endl; } - {//delete index + { // delete index milvus::Status stat = conn->DropIndex(TABLE_NAME); std::cout << "DropIndex function call status: " << stat.message() << std::endl; @@ -349,7 +342,7 @@ ClientTest::Test(const std::string &address, const std::string &port) { std::cout << TABLE_NAME << "(" << row_count << " rows)" << std::endl; } - {//delete by range + { // delete by range milvus::Range rg; rg.start_value = CurrentTmDate(-2); rg.end_value = CurrentTmDate(-3); @@ -358,17 +351,18 @@ ClientTest::Test(const std::string &address, const std::string &port) { std::cout << "DeleteByRange function call status: " << stat.message() << std::endl; } - {//delete table -// Status stat = conn->DropTable(TABLE_NAME); -// std::cout << "DeleteTable function call status: " << stat.message() << std::endl; + { + // delete table + // Status stat = conn->DropTable(TABLE_NAME); + // std::cout << "DeleteTable function call status: " << stat.message() << std::endl; } - {//server status + { // server status std::string status = conn->ServerStatus(); std::cout << "Server status before disconnect: " << status << std::endl; } milvus::Connection::Destroy(conn); - {//server status + { // server status std::string status = conn->ServerStatus(); std::cout << "Server status after disconnect: " << status << std::endl; } diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.h b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.h index 165c36a180efabd8476353a3dd0bb92fa3417d25..b028b63f44c35130f3ada61860893d02127ddf10 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.h +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.h @@ -21,5 +21,6 @@ class ClientTest { public: - void Test(const std::string &address, const std::string &port); + void + Test(const std::string& address, const std::string& port); }; diff --git a/cpp/src/sdk/grpc/ClientProxy.cpp b/cpp/src/sdk/grpc/ClientProxy.cpp index f4cd1e9a38bf30306cdb49cc22e58022795df150..7e1955b04b319e2fca7b60a044b52f204c55e57e 100644 --- a/cpp/src/sdk/grpc/ClientProxy.cpp +++ b/cpp/src/sdk/grpc/ClientProxy.cpp @@ -20,24 +20,20 @@ #include "grpc/gen-milvus/milvus.grpc.pb.h" #include -#include #include +#include //#define GRPC_MULTIPLE_THREAD; namespace milvus { bool -UriCheck(const std::string &uri) { +UriCheck(const std::string& uri) { size_t index = uri.find_first_of(':', 0); - if (index == std::string::npos) { - return false; - } else { - return true; - } + return (index != std::string::npos); } Status -ClientProxy::Connect(const ConnectParam ¶m) { +ClientProxy::Connect(const ConnectParam& param) { std::string uri = param.ip_address + ":" + param.port; channel_ = ::grpc::CreateChannel(uri, ::grpc::InsecureChannelCredentials()); @@ -45,15 +41,15 @@ ClientProxy::Connect(const ConnectParam ¶m) { connected_ = true; client_ptr_ = std::make_shared(channel_); return Status::OK(); - } else { - std::string reason = "connect failed!"; - connected_ = false; - return Status(StatusCode::NotConnected, reason); } + + std::string reason = "connect failed!"; + connected_ = false; + return Status(StatusCode::NotConnected, reason); } Status -ClientProxy::Connect(const std::string &uri) { +ClientProxy::Connect(const std::string& uri) { if (!UriCheck(uri)) { return Status(StatusCode::InvalidAgument, "Invalid uri"); } @@ -71,7 +67,7 @@ ClientProxy::Connected() const { try { std::string info; return client_ptr_->Cmd(info, ""); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::NotConnected, "connection lost: " + std::string(ex.what())); } } @@ -83,7 +79,7 @@ ClientProxy::Disconnect() { connected_ = false; channel_.reset(); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "failed to disconnect: " + std::string(ex.what())); } } @@ -94,22 +90,22 @@ ClientProxy::ClientVersion() const { } Status -ClientProxy::CreateTable(const TableSchema ¶m) { +ClientProxy::CreateTable(const TableSchema& param) { try { ::milvus::grpc::TableSchema schema; schema.set_table_name(param.table_name); schema.set_dimension(param.dimension); schema.set_index_file_size(param.index_file_size); - schema.set_metric_type((int32_t) param.metric_type); + schema.set_metric_type(static_cast(param.metric_type)); return client_ptr_->CreateTable(schema); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "failed to create table: " + std::string(ex.what())); } } bool -ClientProxy::HasTable(const std::string &table_name) { +ClientProxy::HasTable(const std::string& table_name) { Status status = Status::OK(); ::milvus::grpc::TableName grpc_table_name; grpc_table_name.set_table_name(table_name); @@ -118,34 +114,32 @@ ClientProxy::HasTable(const std::string &table_name) { } Status -ClientProxy::DropTable(const std::string &table_name) { +ClientProxy::DropTable(const std::string& table_name) { try { ::milvus::grpc::TableName grpc_table_name; grpc_table_name.set_table_name(table_name); return client_ptr_->DropTable(grpc_table_name); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "failed to drop table: " + std::string(ex.what())); } } Status -ClientProxy::CreateIndex(const IndexParam &index_param) { +ClientProxy::CreateIndex(const IndexParam& index_param) { try { - //TODO: add index params ::milvus::grpc::IndexParam grpc_index_param; grpc_index_param.set_table_name(index_param.table_name); - grpc_index_param.mutable_index()->set_index_type((int32_t) index_param.index_type); + grpc_index_param.mutable_index()->set_index_type(static_cast(index_param.index_type)); grpc_index_param.mutable_index()->set_nlist(index_param.nlist); return client_ptr_->CreateIndex(grpc_index_param); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "failed to build index: " + std::string(ex.what())); } } Status -ClientProxy::Insert(const std::string &table_name, - const std::vector &record_array, - std::vector &id_array) { +ClientProxy::Insert(const std::string& table_name, const std::vector& record_array, + std::vector& id_array) { Status status = Status::OK(); try { //////////////////////////////////////////////////////////////////////////// @@ -154,20 +148,17 @@ ClientProxy::Insert(const std::string &table_name, int thread_count = 10; std::shared_ptr<::milvus::grpc::InsertInfos> insert_info_array( - new ::milvus::grpc::InsertInfos[thread_count], - std::default_delete<::milvus::grpc::InsertInfos[]>() ); + new ::milvus::grpc::InsertInfos[thread_count], std::default_delete<::milvus::grpc::InsertInfos[]>()); - std::shared_ptr<::milvus::grpc::VectorIds> vector_ids_array( - new ::milvus::grpc::VectorIds[thread_count], - std::default_delete<::milvus::grpc::VectorIds[]>() ); + std::shared_ptr<::milvus::grpc::VectorIds> vector_ids_array(new ::milvus::grpc::VectorIds[thread_count], + std::default_delete<::milvus::grpc::VectorIds[]>()); int64_t record_count = record_array.size() / thread_count; for (size_t i = 0; i < thread_count; i++) { insert_info_array.get()[i].set_table_name(table_name); for (size_t j = i * record_count; j < record_count * (i + 1); j++) { - ::milvus::grpc::RowRecord *grpc_record = - insert_info_array.get()[i].add_row_record_array(); + ::milvus::grpc::RowRecord* grpc_record = insert_info_array.get()[i].add_row_record_array(); for (size_t k = 0; k < record_array[j].data.size(); k++) { grpc_record->add_vector_data(record_array[j].data[k]); } @@ -177,16 +168,13 @@ ClientProxy::Insert(const std::string &table_name, std::cout << "*****************************************************\n"; auto start = std::chrono::high_resolution_clock::now(); for (size_t j = 0; j < thread_count; j++) { - threads.push_back( - std::thread(&GrpcClient::InsertVector, client_ptr_, - std::ref(vector_ids_array.get()[j]), std::ref(insert_info_array.get()[j]), - std::ref(status))); + threads.push_back(std::thread(&GrpcClient::InsertVector, client_ptr_, std::ref(vector_ids_array.get()[j]), + std::ref(insert_info_array.get()[j]), std::ref(status))); } std::for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join)); auto finish = std::chrono::high_resolution_clock::now(); - std::cout << - "InsertVector cost: " << std::chrono::duration_cast>(finish - start).count() - << "s\n"; + std::cout << "InsertVector cost: " + << std::chrono::duration_cast>(finish - start).count() << "s\n"; std::cout << "*****************************************************\n"; for (size_t i = 0; i < thread_count; i++) { @@ -198,14 +186,14 @@ ClientProxy::Insert(const std::string &table_name, ::milvus::grpc::InsertParam insert_param; insert_param.set_table_name(table_name); - for (auto &record : record_array) { - ::milvus::grpc::RowRecord *grpc_record = insert_param.add_row_record_array(); + for (auto& record : record_array) { + ::milvus::grpc::RowRecord* grpc_record = insert_param.add_row_record_array(); for (size_t i = 0; i < record.data.size(); i++) { grpc_record->add_vector_data(record.data[i]); } } - //Single thread + // Single thread ::milvus::grpc::VectorIds vector_ids; if (!id_array.empty()) { for (auto i = 0; i < id_array.size(); i++) { @@ -219,7 +207,7 @@ ClientProxy::Insert(const std::string &table_name, } } #endif - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to add vector: " + std::string(ex.what())); } @@ -227,37 +215,34 @@ ClientProxy::Insert(const std::string &table_name, } Status -ClientProxy::Search(const std::string &table_name, - const std::vector &query_record_array, - const std::vector &query_range_array, - int64_t topk, - int64_t nprobe, - std::vector &topk_query_result_array) { +ClientProxy::Search(const std::string& table_name, const std::vector& query_record_array, + const std::vector& query_range_array, int64_t topk, int64_t nprobe, + std::vector& topk_query_result_array) { try { - //step 1: convert vectors data + // step 1: convert vectors data ::milvus::grpc::SearchParam search_param; search_param.set_table_name(table_name); search_param.set_topk(topk); search_param.set_nprobe(nprobe); - for (auto &record : query_record_array) { - ::milvus::grpc::RowRecord *row_record = search_param.add_query_record_array(); - for (auto &rec : record.data) { + for (auto& record : query_record_array) { + ::milvus::grpc::RowRecord* row_record = search_param.add_query_record_array(); + for (auto& rec : record.data) { row_record->add_vector_data(rec); } } - //step 2: convert range array - for (auto &range : query_range_array) { - ::milvus::grpc::Range *grpc_range = search_param.add_query_range_array(); + // step 2: convert range array + for (auto& range : query_range_array) { + ::milvus::grpc::Range* grpc_range = search_param.add_query_range_array(); grpc_range->set_start_value(range.start_value); grpc_range->set_end_value(range.end_value); } - //step 3: search vectors + // step 3: search vectors ::milvus::grpc::TopKQueryResultList topk_query_result_list; Status status = client_ptr_->Search(topk_query_result_list, search_param); - //step 4: convert result array + // step 4: convert result array for (uint64_t i = 0; i < topk_query_result_list.topk_query_result_size(); ++i) { TopKQueryResult result; for (uint64_t j = 0; j < topk_query_result_list.topk_query_result(i).query_result_arrays_size(); ++j) { @@ -270,13 +255,13 @@ ClientProxy::Search(const std::string &table_name, topk_query_result_array.emplace_back(result); } return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to search vectors: " + std::string(ex.what())); } } Status -ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_schema) { +ClientProxy::DescribeTable(const std::string& table_name, TableSchema& table_schema) { try { ::milvus::grpc::TableSchema grpc_schema; @@ -285,27 +270,27 @@ ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_sch table_schema.table_name = grpc_schema.table_name(); table_schema.dimension = grpc_schema.dimension(); table_schema.index_file_size = grpc_schema.index_file_size(); - table_schema.metric_type = (MetricType) grpc_schema.metric_type(); + table_schema.metric_type = static_cast(grpc_schema.metric_type()); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to describe table: " + std::string(ex.what())); } } Status -ClientProxy::CountTable(const std::string &table_name, int64_t &row_count) { +ClientProxy::CountTable(const std::string& table_name, int64_t& row_count) { try { Status status; row_count = client_ptr_->CountTable(table_name, status); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to show tables: " + std::string(ex.what())); } } Status -ClientProxy::ShowTables(std::vector &table_array) { +ClientProxy::ShowTables(std::vector& table_array) { try { Status status; milvus::grpc::TableNameList table_name_list; @@ -316,7 +301,7 @@ ClientProxy::ShowTables(std::vector &table_array) { table_array[i] = table_name_list.table_names(i); } return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to show tables: " + std::string(ex.what())); } } @@ -328,7 +313,7 @@ ClientProxy::ServerVersion() const { std::string version; Status status = client_ptr_->Cmd(version, "version"); return version; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return ""; } } @@ -343,7 +328,7 @@ ClientProxy::ServerStatus() const { std::string dummy; Status status = client_ptr_->Cmd(dummy, ""); return "server alive"; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return "connection lost"; } } @@ -358,62 +343,62 @@ ClientProxy::DumpTaskTables() const { std::string dummy; Status status = client_ptr_->Cmd(dummy, "tasktable"); return dummy; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return "connection lost"; } } Status -ClientProxy::DeleteByRange(milvus::Range &range, const std::string &table_name) { +ClientProxy::DeleteByRange(milvus::Range& range, const std::string& table_name) { try { ::milvus::grpc::DeleteByRangeParam delete_by_range_param; delete_by_range_param.set_table_name(table_name); delete_by_range_param.mutable_range()->set_start_value(range.start_value); delete_by_range_param.mutable_range()->set_end_value(range.end_value); return client_ptr_->DeleteByRange(delete_by_range_param); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to delete by range: " + std::string(ex.what())); } } Status -ClientProxy::PreloadTable(const std::string &table_name) const { +ClientProxy::PreloadTable(const std::string& table_name) const { try { ::milvus::grpc::TableName grpc_table_name; grpc_table_name.set_table_name(table_name); Status status = client_ptr_->PreloadTable(grpc_table_name); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to preload tables: " + std::string(ex.what())); } } Status -ClientProxy::DescribeIndex(const std::string &table_name, IndexParam &index_param) const { +ClientProxy::DescribeIndex(const std::string& table_name, IndexParam& index_param) const { try { ::milvus::grpc::TableName grpc_table_name; grpc_table_name.set_table_name(table_name); ::milvus::grpc::IndexParam grpc_index_param; Status status = client_ptr_->DescribeIndex(grpc_table_name, grpc_index_param); - index_param.index_type = (IndexType) (grpc_index_param.mutable_index()->index_type()); + index_param.index_type = static_cast(grpc_index_param.mutable_index()->index_type()); index_param.nlist = grpc_index_param.mutable_index()->nlist(); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to describe index: " + std::string(ex.what())); } } Status -ClientProxy::DropIndex(const std::string &table_name) const { +ClientProxy::DropIndex(const std::string& table_name) const { try { ::milvus::grpc::TableName grpc_table_name; grpc_table_name.set_table_name(table_name); Status status = client_ptr_->DropIndex(grpc_table_name); return status; - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "fail to drop index: " + std::string(ex.what())); } } -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/grpc/ClientProxy.h b/cpp/src/sdk/grpc/ClientProxy.h index 044ce44847d1e8c27b6e50482cc4b05e2431e877..dbeacc138052e7c19ac34f13439569c53a243d11 100644 --- a/cpp/src/sdk/grpc/ClientProxy.h +++ b/cpp/src/sdk/grpc/ClientProxy.h @@ -17,12 +17,12 @@ #pragma once -#include "MilvusApi.h" #include "GrpcClient.h" +#include "MilvusApi.h" -#include -#include #include +#include +#include namespace milvus { @@ -30,10 +30,10 @@ class ClientProxy : public Connection { public: // Implementations of the Connection interface Status - Connect(const ConnectParam ¶m) override; + Connect(const ConnectParam& param) override; Status - Connect(const std::string &uri) override; + Connect(const std::string& uri) override; Status Connected() const override; @@ -42,38 +42,34 @@ class ClientProxy : public Connection { Disconnect() override; Status - CreateTable(const TableSchema ¶m) override; + CreateTable(const TableSchema& param) override; bool - HasTable(const std::string &table_name) override; + HasTable(const std::string& table_name) override; Status - DropTable(const std::string &table_name) override; + DropTable(const std::string& table_name) override; Status - CreateIndex(const IndexParam &index_param) override; + CreateIndex(const IndexParam& index_param) override; Status - Insert(const std::string &table_name, - const std::vector &record_array, - std::vector &id_array) override; + Insert(const std::string& table_name, const std::vector& record_array, + std::vector& id_array) override; Status - Search(const std::string &table_name, - const std::vector &query_record_array, - const std::vector &query_range_array, - int64_t topk, - int64_t nprobe, - std::vector &topk_query_result_array) override; + Search(const std::string& table_name, const std::vector& query_record_array, + const std::vector& query_range_array, int64_t topk, int64_t nprobe, + std::vector& topk_query_result_array) override; Status - DescribeTable(const std::string &table_name, TableSchema &table_schema) override; + DescribeTable(const std::string& table_name, TableSchema& table_schema) override; Status - CountTable(const std::string &table_name, int64_t &row_count) override; + CountTable(const std::string& table_name, int64_t& row_count) override; Status - ShowTables(std::vector &table_array) override; + ShowTables(std::vector& table_array) override; std::string ClientVersion() const override; @@ -88,17 +84,16 @@ class ClientProxy : public Connection { DumpTaskTables() const override; Status - DeleteByRange(Range &range, - const std::string &table_name) override; + DeleteByRange(Range& range, const std::string& table_name) override; Status - PreloadTable(const std::string &table_name) const override; + PreloadTable(const std::string& table_name) const override; Status - DescribeIndex(const std::string &table_name, IndexParam &index_param) const override; + DescribeIndex(const std::string& table_name, IndexParam& index_param) const override; Status - DropIndex(const std::string &table_name) const override; + DropIndex(const std::string& table_name) const override; private: std::shared_ptr<::grpc::Channel> channel_; @@ -108,4 +103,4 @@ class ClientProxy : public Connection { bool connected_ = false; }; -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/grpc/GrpcClient.cpp b/cpp/src/sdk/grpc/GrpcClient.cpp index df86f0a3e7402bb05be25956287a2821ef9fbbef..5c27c3b73f73cfc84a5a7f5640ab035bec52fd9e 100644 --- a/cpp/src/sdk/grpc/GrpcClient.cpp +++ b/cpp/src/sdk/grpc/GrpcClient.cpp @@ -23,9 +23,9 @@ #include #include -#include -#include #include +#include +#include using grpc::Channel; using grpc::ClientContext; @@ -35,14 +35,14 @@ using grpc::ClientWriter; using grpc::Status; namespace milvus { -GrpcClient::GrpcClient(std::shared_ptr<::grpc::Channel> &channel) +GrpcClient::GrpcClient(std::shared_ptr<::grpc::Channel>& channel) : stub_(::milvus::grpc::MilvusService::NewStub(channel)) { } GrpcClient::~GrpcClient() = default; Status -GrpcClient::CreateTable(const ::milvus::grpc::TableSchema &table_schema) { +GrpcClient::CreateTable(const ::milvus::grpc::TableSchema& table_schema) { ClientContext context; grpc::Status response; ::grpc::Status grpc_status = stub_->CreateTable(&context, table_schema, &response); @@ -60,8 +60,7 @@ GrpcClient::CreateTable(const ::milvus::grpc::TableSchema &table_schema) { } bool -GrpcClient::HasTable(const ::milvus::grpc::TableName &table_name, - Status &status) { +GrpcClient::HasTable(const ::milvus::grpc::TableName& table_name, Status& status) { ClientContext context; ::milvus::grpc::BoolReply response; ::grpc::Status grpc_status = stub_->HasTable(&context, table_name, &response); @@ -79,7 +78,7 @@ GrpcClient::HasTable(const ::milvus::grpc::TableName &table_name, } Status -GrpcClient::DropTable(const ::milvus::grpc::TableName &table_name) { +GrpcClient::DropTable(const ::milvus::grpc::TableName& table_name) { ClientContext context; grpc::Status response; ::grpc::Status grpc_status = stub_->DropTable(&context, table_name, &response); @@ -97,7 +96,7 @@ GrpcClient::DropTable(const ::milvus::grpc::TableName &table_name) { } Status -GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam &index_param) { +GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam& index_param) { ClientContext context; grpc::Status response; ::grpc::Status grpc_status = stub_->CreateIndex(&context, index_param, &response); @@ -115,9 +114,8 @@ GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam &index_param) { } void -GrpcClient::Insert(::milvus::grpc::VectorIds &vector_ids, - const ::milvus::grpc::InsertParam &insert_param, - Status &status) { +GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids, const ::milvus::grpc::InsertParam& insert_param, + Status& status) { ClientContext context; ::grpc::Status grpc_status = stub_->Insert(&context, insert_param, &vector_ids); @@ -136,8 +134,8 @@ GrpcClient::Insert(::milvus::grpc::VectorIds &vector_ids, } Status -GrpcClient::Search(::milvus::grpc::TopKQueryResultList &topk_query_result_list, - const ::milvus::grpc::SearchParam &search_param) { +GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list, + const ::milvus::grpc::SearchParam& search_param) { ::milvus::grpc::TopKQueryResult query_result; ClientContext context; ::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result_list); @@ -149,16 +147,14 @@ GrpcClient::Search(::milvus::grpc::TopKQueryResultList &topk_query_result_list, } if (topk_query_result_list.status().error_code() != grpc::SUCCESS) { std::cerr << topk_query_result_list.status().reason() << std::endl; - return Status(StatusCode::ServerFailed, - topk_query_result_list.status().reason()); + return Status(StatusCode::ServerFailed, topk_query_result_list.status().reason()); } return Status::OK(); } Status -GrpcClient::DescribeTable(::milvus::grpc::TableSchema &grpc_schema, - const std::string &table_name) { +GrpcClient::DescribeTable(::milvus::grpc::TableSchema& grpc_schema, const std::string& table_name) { ClientContext context; ::milvus::grpc::TableName grpc_tablename; grpc_tablename.set_table_name(table_name); @@ -172,15 +168,14 @@ GrpcClient::DescribeTable(::milvus::grpc::TableSchema &grpc_schema, if (grpc_schema.status().error_code() != grpc::SUCCESS) { std::cerr << grpc_schema.status().reason() << std::endl; - return Status(StatusCode::ServerFailed, - grpc_schema.status().reason()); + return Status(StatusCode::ServerFailed, grpc_schema.status().reason()); } return Status::OK(); } int64_t -GrpcClient::CountTable(const std::string &table_name, Status &status) { +GrpcClient::CountTable(const std::string& table_name, Status& status) { ClientContext context; ::milvus::grpc::TableRowCount response; ::milvus::grpc::TableName grpc_tablename; @@ -204,7 +199,7 @@ GrpcClient::CountTable(const std::string &table_name, Status &status) { } Status -GrpcClient::ShowTables(milvus::grpc::TableNameList &table_name_list) { +GrpcClient::ShowTables(milvus::grpc::TableNameList& table_name_list) { ClientContext context; ::milvus::grpc::Command command; ::grpc::Status grpc_status = stub_->ShowTables(&context, command, &table_name_list); @@ -217,16 +212,14 @@ GrpcClient::ShowTables(milvus::grpc::TableNameList &table_name_list) { if (table_name_list.status().error_code() != grpc::SUCCESS) { std::cerr << table_name_list.status().reason() << std::endl; - return Status(StatusCode::ServerFailed, - table_name_list.status().reason()); + return Status(StatusCode::ServerFailed, table_name_list.status().reason()); } return Status::OK(); } Status -GrpcClient::Cmd(std::string &result, - const std::string &cmd) { +GrpcClient::Cmd(std::string& result, const std::string& cmd) { ClientContext context; ::milvus::grpc::StringReply response; ::milvus::grpc::Command command; @@ -248,7 +241,7 @@ GrpcClient::Cmd(std::string &result, } Status -GrpcClient::PreloadTable(milvus::grpc::TableName &table_name) { +GrpcClient::PreloadTable(milvus::grpc::TableName& table_name) { ClientContext context; ::milvus::grpc::Status response; ::grpc::Status grpc_status = stub_->PreloadTable(&context, table_name, &response); @@ -266,7 +259,7 @@ GrpcClient::PreloadTable(milvus::grpc::TableName &table_name) { } Status -GrpcClient::DeleteByRange(grpc::DeleteByRangeParam &delete_by_range_param) { +GrpcClient::DeleteByRange(grpc::DeleteByRangeParam& delete_by_range_param) { ClientContext context; ::milvus::grpc::Status response; ::grpc::Status grpc_status = stub_->DeleteByRange(&context, delete_by_range_param, &response); @@ -290,7 +283,7 @@ GrpcClient::Disconnect() { } Status -GrpcClient::DescribeIndex(grpc::TableName &table_name, grpc::IndexParam &index_param) { +GrpcClient::DescribeIndex(grpc::TableName& table_name, grpc::IndexParam& index_param) { ClientContext context; ::grpc::Status grpc_status = stub_->DescribeIndex(&context, table_name, &index_param); @@ -307,7 +300,7 @@ GrpcClient::DescribeIndex(grpc::TableName &table_name, grpc::IndexParam &index_p } Status -GrpcClient::DropIndex(grpc::TableName &table_name) { +GrpcClient::DropIndex(grpc::TableName& table_name) { ClientContext context; ::milvus::grpc::Status response; ::grpc::Status grpc_status = stub_->DropIndex(&context, table_name, &response); @@ -324,4 +317,4 @@ GrpcClient::DropIndex(grpc::TableName &table_name) { return Status::OK(); } -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/grpc/GrpcClient.h b/cpp/src/sdk/grpc/GrpcClient.h index 8f81e83ae8fb35bbed04ed5bf8105b2c58747e44..d2e6ae509535aa0420d6bfd83d43c8e1a9aa9a72 100644 --- a/cpp/src/sdk/grpc/GrpcClient.h +++ b/cpp/src/sdk/grpc/GrpcClient.h @@ -37,56 +37,51 @@ namespace milvus { class GrpcClient { public: - explicit GrpcClient(std::shared_ptr<::grpc::Channel> &channel); + explicit GrpcClient(std::shared_ptr<::grpc::Channel>& channel); - virtual - ~GrpcClient(); + virtual ~GrpcClient(); Status - CreateTable(const grpc::TableSchema &table_schema); + CreateTable(const grpc::TableSchema& table_schema); bool - HasTable(const grpc::TableName &table_name, Status &status); + HasTable(const grpc::TableName& table_name, Status& status); Status - DropTable(const grpc::TableName &table_name); + DropTable(const grpc::TableName& table_name); Status - CreateIndex(const grpc::IndexParam &index_param); + CreateIndex(const grpc::IndexParam& index_param); void - Insert(grpc::VectorIds &vector_ids, - const grpc::InsertParam &insert_param, - Status &status); + Insert(grpc::VectorIds& vector_ids, const grpc::InsertParam& insert_param, Status& status); Status - Search(::milvus::grpc::TopKQueryResultList &topk_query_result_list, - const grpc::SearchParam &search_param); + Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list, const grpc::SearchParam& search_param); Status - DescribeTable(grpc::TableSchema &grpc_schema, - const std::string &table_name); + DescribeTable(grpc::TableSchema& grpc_schema, const std::string& table_name); int64_t - CountTable(const std::string &table_name, Status &status); + CountTable(const std::string& table_name, Status& status); Status - ShowTables(milvus::grpc::TableNameList &table_name_list); + ShowTables(milvus::grpc::TableNameList& table_name_list); Status - Cmd(std::string &result, const std::string &cmd); + Cmd(std::string& result, const std::string& cmd); Status - DeleteByRange(grpc::DeleteByRangeParam &delete_by_range_param); + DeleteByRange(grpc::DeleteByRangeParam& delete_by_range_param); Status - PreloadTable(grpc::TableName &table_name); + PreloadTable(grpc::TableName& table_name); Status - DescribeIndex(grpc::TableName &table_name, grpc::IndexParam &index_param); + DescribeIndex(grpc::TableName& table_name, grpc::IndexParam& index_param); Status - DropIndex(grpc::TableName &table_name); + DropIndex(grpc::TableName& table_name); Status Disconnect(); @@ -95,4 +90,4 @@ class GrpcClient { std::unique_ptr stub_; }; -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/include/MilvusApi.h b/cpp/src/sdk/include/MilvusApi.h index e6025fd52e541c7c36e612f50c7d483ab901869b..9425ef3be33d6205039891d47c60afc9b077e68d 100644 --- a/cpp/src/sdk/include/MilvusApi.h +++ b/cpp/src/sdk/include/MilvusApi.h @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "Status.h" +#include #include #include -#include /** \brief Milvus SDK namespace */ @@ -48,18 +47,18 @@ enum class MetricType { * @brief Connect API parameter */ struct ConnectParam { - std::string ip_address; ///< Server IP address - std::string port; ///< Server PORT + std::string ip_address; ///< Server IP address + std::string port; ///< Server PORT }; /** * @brief Table Schema */ struct TableSchema { - std::string table_name; ///< Table name - int64_t dimension = 0; ///< Vector dimension, must be a positive value - int64_t index_file_size = 0; ///< Index file size, must be a positive value - MetricType metric_type = MetricType::L2; ///< Index metric type + std::string table_name; ///< Table name + int64_t dimension = 0; ///< Vector dimension, must be a positive value + int64_t index_file_size = 0; ///< Index file size, must be a positive value + MetricType metric_type = MetricType::L2; ///< Index metric type }; /** @@ -67,30 +66,30 @@ struct TableSchema { * for DATE partition, the format is like: 'year-month-day' */ struct Range { - std::string start_value; ///< Range start - std::string end_value; ///< Range stop + std::string start_value; ///< Range start + std::string end_value; ///< Range stop }; /** * @brief Record inserted */ struct RowRecord { - std::vector data; ///< Vector raw data + std::vector data; ///< Vector raw data }; /** * @brief Query result */ struct QueryResult { - int64_t id; ///< Output result - double distance; ///< Vector similarity distance + int64_t id; ///< Output result + double distance; ///< Vector similarity distance }; /** * @brief TopK query result */ struct TopKQueryResult { - std::vector query_result_arrays; ///< TopK query result + std::vector query_result_arrays; ///< TopK query result }; /** @@ -129,7 +128,7 @@ class Connection { */ static Status - Destroy(std::shared_ptr &connection_ptr); + Destroy(std::shared_ptr& connection_ptr); /** * @brief Connect @@ -143,7 +142,7 @@ class Connection { */ virtual Status - Connect(const ConnectParam ¶m) = 0; + Connect(const ConnectParam& param) = 0; /** * @brief Connect @@ -156,7 +155,7 @@ class Connection { * @return Indicate if connect is successful */ virtual Status - Connect(const std::string &uri) = 0; + Connect(const std::string& uri) = 0; /** * @brief connected @@ -188,7 +187,7 @@ class Connection { * @return Indicate if table is created successfully */ virtual Status - CreateTable(const TableSchema ¶m) = 0; + CreateTable(const TableSchema& param) = 0; /** * @brief Test table existence method @@ -200,7 +199,7 @@ class Connection { * @return Indicate if table is cexist */ virtual bool - HasTable(const std::string &table_name) = 0; + HasTable(const std::string& table_name) = 0; /** * @brief Delete table method @@ -212,7 +211,7 @@ class Connection { * @return Indicate if table is delete successfully. */ virtual Status - DropTable(const std::string &table_name) = 0; + DropTable(const std::string& table_name) = 0; /** * @brief Create index method @@ -228,7 +227,7 @@ class Connection { * @return Indicate if build index successfully. */ virtual Status - CreateIndex(const IndexParam &index_param) = 0; + CreateIndex(const IndexParam& index_param) = 0; /** * @brief Add vector to table @@ -242,9 +241,8 @@ class Connection { * @return Indicate if vector array are inserted successfully */ virtual Status - Insert(const std::string &table_name, - const std::vector &record_array, - std::vector &id_array) = 0; + Insert(const std::string& table_name, const std::vector& record_array, + std::vector& id_array) = 0; /** * @brief Search vector @@ -260,12 +258,9 @@ class Connection { * @return Indicate if query is successful. */ virtual Status - Search(const std::string &table_name, - const std::vector &query_record_array, - const std::vector &query_range_array, - int64_t topk, - int64_t nprobe, - std::vector &topk_query_result_array) = 0; + Search(const std::string& table_name, const std::vector& query_record_array, + const std::vector& query_range_array, int64_t topk, int64_t nprobe, + std::vector& topk_query_result_array) = 0; /** * @brief Show table description @@ -278,7 +273,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - DescribeTable(const std::string &table_name, TableSchema &table_schema) = 0; + DescribeTable(const std::string& table_name, TableSchema& table_schema) = 0; /** * @brief Get table row count @@ -291,8 +286,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - CountTable(const std::string &table_name, - int64_t &row_count) = 0; + CountTable(const std::string& table_name, int64_t& row_count) = 0; /** * @brief Show all tables in database @@ -304,7 +298,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - ShowTables(std::vector &table_array) = 0; + ShowTables(std::vector& table_array) = 0; /** * @brief Give the client version @@ -350,8 +344,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - DeleteByRange(Range &range, - const std::string &table_name) = 0; + DeleteByRange(Range& range, const std::string& table_name) = 0; /** * @brief preload table @@ -363,7 +356,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - PreloadTable(const std::string &table_name) const = 0; + PreloadTable(const std::string& table_name) const = 0; /** * @brief describe index @@ -375,7 +368,7 @@ class Connection { * @return index informations and indicate if this operation is successful. */ virtual Status - DescribeIndex(const std::string &table_name, IndexParam &index_param) const = 0; + DescribeIndex(const std::string& table_name, IndexParam& index_param) const = 0; /** * @brief drop index @@ -387,7 +380,7 @@ class Connection { * @return Indicate if this operation is successful. */ virtual Status - DropIndex(const std::string &table_name) const = 0; + DropIndex(const std::string& table_name) const = 0; }; -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/include/Status.h b/cpp/src/sdk/include/Status.h index 670d9662b4a23700fb39d163225a1c5071366acf..008f9956d254a9904e0b2a2394f53fd26c1d3e66 100644 --- a/cpp/src/sdk/include/Status.h +++ b/cpp/src/sdk/include/Status.h @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include /** \brief Milvus SDK namespace -*/ + */ namespace milvus { /** -* @brief Status Code for SDK interface return -*/ + * @brief Status Code for SDK interface return + */ enum class StatusCode { OK = 0, @@ -42,23 +41,23 @@ enum class StatusCode { }; /** -* @brief Status for SDK interface return -*/ + * @brief Status for SDK interface return + */ class Status { public: - Status(StatusCode code, const std::string &msg); + Status(StatusCode code, const std::string& msg); Status(); ~Status(); - Status(const Status &s); + Status(const Status& s); - Status & - operator=(const Status &s); + Status& + operator=(const Status& s); - Status(Status &&s); + Status(Status&& s); - Status & - operator=(Status &&s); + Status& + operator=(Status&& s); static Status OK() { @@ -72,7 +71,7 @@ class Status { StatusCode code() const { - return (state_ == nullptr) ? StatusCode::OK : *(StatusCode *) (state_); + return (state_ == nullptr) ? StatusCode::OK : *(StatusCode*)(state_); } std::string @@ -80,13 +79,13 @@ class Status { private: inline void - CopyFrom(const Status &s); + CopyFrom(const Status& s); inline void - MoveFrom(Status &s); + MoveFrom(Status& s); private: - const char *state_ = nullptr; -}; // Status + char* state_ = nullptr; +}; // Status -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/interface/ConnectionImpl.cpp b/cpp/src/sdk/interface/ConnectionImpl.cpp index 5a01cd1a797bf536ad24d52b70711eeda48da1c9..7034ce4a4d03c91dfd5c451a611ef584af52f864 100644 --- a/cpp/src/sdk/interface/ConnectionImpl.cpp +++ b/cpp/src/sdk/interface/ConnectionImpl.cpp @@ -25,7 +25,7 @@ Connection::Create() { } Status -Connection::Destroy(std::shared_ptr &connection_ptr) { +Connection::Destroy(std::shared_ptr& connection_ptr) { if (connection_ptr != nullptr) { return connection_ptr->Disconnect(); } @@ -38,12 +38,12 @@ ConnectionImpl::ConnectionImpl() { } Status -ConnectionImpl::Connect(const ConnectParam ¶m) { +ConnectionImpl::Connect(const ConnectParam& param) { return client_proxy_->Connect(param); } Status -ConnectionImpl::Connect(const std::string &uri) { +ConnectionImpl::Connect(const std::string& uri) { return client_proxy_->Connect(uri); } @@ -63,55 +63,51 @@ ConnectionImpl::ClientVersion() const { } Status -ConnectionImpl::CreateTable(const TableSchema ¶m) { +ConnectionImpl::CreateTable(const TableSchema& param) { return client_proxy_->CreateTable(param); } bool -ConnectionImpl::HasTable(const std::string &table_name) { +ConnectionImpl::HasTable(const std::string& table_name) { return client_proxy_->HasTable(table_name); } Status -ConnectionImpl::DropTable(const std::string &table_name) { +ConnectionImpl::DropTable(const std::string& table_name) { return client_proxy_->DropTable(table_name); } Status -ConnectionImpl::CreateIndex(const IndexParam &index_param) { +ConnectionImpl::CreateIndex(const IndexParam& index_param) { return client_proxy_->CreateIndex(index_param); } Status -ConnectionImpl::Insert(const std::string &table_name, - const std::vector &record_array, - std::vector &id_array) { +ConnectionImpl::Insert(const std::string& table_name, const std::vector& record_array, + std::vector& id_array) { return client_proxy_->Insert(table_name, record_array, id_array); } Status -ConnectionImpl::Search(const std::string &table_name, - const std::vector &query_record_array, - const std::vector &query_range_array, - int64_t topk, - int64_t nprobe, - std::vector &topk_query_result_array) { - return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, - nprobe, topk_query_result_array); +ConnectionImpl::Search(const std::string& table_name, const std::vector& query_record_array, + const std::vector& query_range_array, int64_t topk, int64_t nprobe, + std::vector& topk_query_result_array) { + return client_proxy_->Search(table_name, query_record_array, query_range_array, topk, nprobe, + topk_query_result_array); } Status -ConnectionImpl::DescribeTable(const std::string &table_name, TableSchema &table_schema) { +ConnectionImpl::DescribeTable(const std::string& table_name, TableSchema& table_schema) { return client_proxy_->DescribeTable(table_name, table_schema); } Status -ConnectionImpl::CountTable(const std::string &table_name, int64_t &row_count) { +ConnectionImpl::CountTable(const std::string& table_name, int64_t& row_count) { return client_proxy_->CountTable(table_name, row_count); } Status -ConnectionImpl::ShowTables(std::vector &table_array) { +ConnectionImpl::ShowTables(std::vector& table_array) { return client_proxy_->ShowTables(table_array); } @@ -131,24 +127,23 @@ ConnectionImpl::DumpTaskTables() const { } Status -ConnectionImpl::DeleteByRange(Range &range, - const std::string &table_name) { +ConnectionImpl::DeleteByRange(Range& range, const std::string& table_name) { return client_proxy_->DeleteByRange(range, table_name); } Status -ConnectionImpl::PreloadTable(const std::string &table_name) const { +ConnectionImpl::PreloadTable(const std::string& table_name) const { return client_proxy_->PreloadTable(table_name); } Status -ConnectionImpl::DescribeIndex(const std::string &table_name, IndexParam &index_param) const { +ConnectionImpl::DescribeIndex(const std::string& table_name, IndexParam& index_param) const { return client_proxy_->DescribeIndex(table_name, index_param); } Status -ConnectionImpl::DropIndex(const std::string &table_name) const { +ConnectionImpl::DropIndex(const std::string& table_name) const { return client_proxy_->DropIndex(table_name); } -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/interface/ConnectionImpl.h b/cpp/src/sdk/interface/ConnectionImpl.h index 7d5e09088249b58b74d0d3f90dea822f80d9528b..6bc3432bc4ef90059964db46cb640160c8efffb5 100644 --- a/cpp/src/sdk/interface/ConnectionImpl.h +++ b/cpp/src/sdk/interface/ConnectionImpl.h @@ -20,9 +20,9 @@ #include "MilvusApi.h" #include "sdk/grpc/ClientProxy.h" -#include #include #include +#include namespace milvus { @@ -32,10 +32,10 @@ class ConnectionImpl : public Connection { // Implementations of the Connection interface Status - Connect(const ConnectParam ¶m) override; + Connect(const ConnectParam& param) override; Status - Connect(const std::string &uri) override; + Connect(const std::string& uri) override; Status Connected() const override; @@ -44,38 +44,34 @@ class ConnectionImpl : public Connection { Disconnect() override; Status - CreateTable(const TableSchema ¶m) override; + CreateTable(const TableSchema& param) override; bool - HasTable(const std::string &table_name) override; + HasTable(const std::string& table_name) override; Status - DropTable(const std::string &table_name) override; + DropTable(const std::string& table_name) override; Status - CreateIndex(const IndexParam &index_param) override; + CreateIndex(const IndexParam& index_param) override; Status - Insert(const std::string &table_name, - const std::vector &record_array, - std::vector &id_array) override; + Insert(const std::string& table_name, const std::vector& record_array, + std::vector& id_array) override; Status - Search(const std::string &table_name, - const std::vector &query_record_array, - const std::vector &query_range_array, - int64_t topk, - int64_t nprobe, - std::vector &topk_query_result_array) override; + Search(const std::string& table_name, const std::vector& query_record_array, + const std::vector& query_range_array, int64_t topk, int64_t nprobe, + std::vector& topk_query_result_array) override; Status - DescribeTable(const std::string &table_name, TableSchema &table_schema) override; + DescribeTable(const std::string& table_name, TableSchema& table_schema) override; Status - CountTable(const std::string &table_name, int64_t &row_count) override; + CountTable(const std::string& table_name, int64_t& row_count) override; Status - ShowTables(std::vector &table_array) override; + ShowTables(std::vector& table_array) override; std::string ClientVersion() const override; @@ -90,20 +86,19 @@ class ConnectionImpl : public Connection { DumpTaskTables() const override; Status - DeleteByRange(Range &range, - const std::string &table_name) override; + DeleteByRange(Range& range, const std::string& table_name) override; Status - PreloadTable(const std::string &table_name) const override; + PreloadTable(const std::string& table_name) const override; Status - DescribeIndex(const std::string &table_name, IndexParam &index_param) const override; + DescribeIndex(const std::string& table_name, IndexParam& index_param) const override; Status - DropIndex(const std::string &table_name) const override; + DropIndex(const std::string& table_name) const override; private: std::shared_ptr client_proxy_; }; -} // namespace milvus +} // namespace milvus diff --git a/cpp/src/sdk/interface/Status.cpp b/cpp/src/sdk/interface/Status.cpp index a8780f2ddde94e992c5360596ca72c13aff09fa6..a5e89556f29900d4a66f97fb2fce61b3e48d3aa4 100644 --- a/cpp/src/sdk/interface/Status.cpp +++ b/cpp/src/sdk/interface/Status.cpp @@ -23,12 +23,12 @@ namespace milvus { constexpr int CODE_WIDTH = sizeof(StatusCode); -Status::Status(StatusCode code, const std::string &msg) { - //4 bytes store code - //4 bytes store message length - //the left bytes store message string - const uint32_t length = (uint32_t) msg.size(); - char *result = new char[length + sizeof(length) + CODE_WIDTH]; +Status::Status(StatusCode code, const std::string& msg) { + // 4 bytes store code + // 4 bytes store message length + // the left bytes store message string + const uint32_t length = (uint32_t)msg.size(); + auto result = new char[length + sizeof(length) + CODE_WIDTH]; memcpy(result, &code, CODE_WIDTH); memcpy(result + CODE_WIDTH, &length, sizeof(length)); memcpy(result + sizeof(length) + CODE_WIDTH, msg.data(), length); @@ -36,38 +36,35 @@ Status::Status(StatusCode code, const std::string &msg) { state_ = result; } -Status::Status() - : state_(nullptr) { +Status::Status() : state_(nullptr) { } Status::~Status() { delete state_; } -Status::Status(const Status &s) - : state_(nullptr) { +Status::Status(const Status& s) : state_(nullptr) { CopyFrom(s); } -Status & -Status::operator=(const Status &s) { +Status& +Status::operator=(const Status& s) { CopyFrom(s); return *this; } -Status::Status(Status &&s) - : state_(nullptr) { +Status::Status(Status&& s) : state_(nullptr) { MoveFrom(s); } -Status & -Status::operator=(Status &&s) { +Status& +Status::operator=(Status&& s) { MoveFrom(s); return *this; } void -Status::CopyFrom(const Status &s) { +Status::CopyFrom(const Status& s) { delete state_; state_ = nullptr; if (s.state_ == nullptr) { @@ -78,11 +75,11 @@ Status::CopyFrom(const Status &s) { memcpy(&length, s.state_ + CODE_WIDTH, sizeof(length)); int buff_len = length + sizeof(length) + CODE_WIDTH; state_ = new char[buff_len]; - memcpy((void *) state_, (void *) s.state_, buff_len); + memcpy(state_, s.state_, buff_len); } void -Status::MoveFrom(Status &s) { +Status::MoveFrom(Status& s) { delete state_; state_ = s.state_; s.state_ = nullptr; @@ -104,6 +101,4 @@ Status::message() const { return msg; } -} // namespace milvus - - +} // namespace milvus diff --git a/cpp/src/server/Config.cpp b/cpp/src/server/Config.cpp index d0b2b84539c541a94b7f5902ab8fafaeed591e50..d383eb5edd96f2da626430a5a49c2bd630550d69 100644 --- a/cpp/src/server/Config.cpp +++ b/cpp/src/server/Config.cpp @@ -17,33 +17,32 @@ #include "server/Config.h" +#include #include #include #include -#include -#include #include -#include +#include #include +#include #include "config/ConfigMgr.h" #include "utils/CommonUtil.h" #include "utils/ValidationUtil.h" -namespace zilliz { namespace milvus { namespace server { constexpr uint64_t GB = 1UL << 30; -Config & +Config& Config::GetInstance() { static Config config_inst; return config_inst; } Status -Config::LoadConfigFile(const std::string &filename) { +Config::LoadConfigFile(const std::string& filename) { if (filename.empty()) { std::cerr << "ERROR: need specify config file" << std::endl; exit(1); @@ -56,14 +55,13 @@ Config::LoadConfigFile(const std::string &filename) { } try { - ConfigMgr *mgr = const_cast(ConfigMgr::GetInstance()); + ConfigMgr* mgr = const_cast(ConfigMgr::GetInstance()); ErrorCode err = mgr->LoadConfigFile(filename); if (err != 0) { std::cerr << "Server failed to load config file: " << filename << std::endl; exit(1); } - } - catch (YAML::Exception &e) { + } catch (YAML::Exception& e) { std::cerr << "Server failed to load config file: " << filename << std::endl; exit(1); } @@ -78,100 +76,146 @@ Config::ValidateConfig() { /* server config */ std::string server_addr; s = GetServerConfigAddress(server_addr); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string server_port; s = GetServerConfigPort(server_port); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string server_mode; s = GetServerConfigDeployMode(server_mode); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string server_time_zone; s = GetServerConfigTimeZone(server_time_zone); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* db config */ std::string db_primary_path; s = GetDBConfigPrimaryPath(db_primary_path); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string db_secondary_path; s = GetDBConfigSecondaryPath(db_secondary_path); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string db_backend_url; s = GetDBConfigBackendUrl(db_backend_url); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } int32_t db_archive_disk_threshold; s = GetDBConfigArchiveDiskThreshold(db_archive_disk_threshold); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } int32_t db_archive_days_threshold; s = GetDBConfigArchiveDaysThreshold(db_archive_days_threshold); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } int32_t db_insert_buffer_size; s = GetDBConfigInsertBufferSize(db_insert_buffer_size); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } int32_t db_build_index_gpu; s = GetDBConfigBuildIndexGPU(db_build_index_gpu); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* metric config */ bool metric_enable_monitor; s = GetMetricConfigEnableMonitor(metric_enable_monitor); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string metric_collector; s = GetMetricConfigCollector(metric_collector); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string metric_prometheus_port; s = GetMetricConfigPrometheusPort(metric_prometheus_port); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* cache config */ - int32_t cache_cpu_mem_capacity; - s = GetCacheConfigCpuMemCapacity(cache_cpu_mem_capacity); - if (!s.ok()) return s; + int32_t cache_cpu_cache_capacity; + s = GetCacheConfigCpuCacheCapacity(cache_cpu_cache_capacity); + if (!s.ok()) { + return s; + } - float cache_cpu_mem_threshold; - s = GetCacheConfigCpuMemThreshold(cache_cpu_mem_threshold); - if (!s.ok()) return s; + float cache_cpu_cache_threshold; + s = GetCacheConfigCpuCacheThreshold(cache_cpu_cache_threshold); + if (!s.ok()) { + return s; + } - int32_t cache_gpu_mem_capacity; - s = GetCacheConfigGpuMemCapacity(cache_gpu_mem_capacity); - if (!s.ok()) return s; + int32_t cache_gpu_cache_capacity; + s = GetCacheConfigGpuCacheCapacity(cache_gpu_cache_capacity); + if (!s.ok()) { + return s; + } - float cache_gpu_mem_threshold; - s = GetCacheConfigGpuMemThreshold(cache_gpu_mem_threshold); - if (!s.ok()) return s; + float cache_gpu_cache_threshold; + s = GetCacheConfigGpuCacheThreshold(cache_gpu_cache_threshold); + if (!s.ok()) { + return s; + } bool cache_insert_data; s = GetCacheConfigCacheInsertData(cache_insert_data); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* engine config */ - int32_t engine_blas_threshold; - s = GetEngineConfigBlasThreshold(engine_blas_threshold); - if (!s.ok()) return s; + int32_t engine_use_blas_threshold; + s = GetEngineConfigUseBlasThreshold(engine_use_blas_threshold); + if (!s.ok()) { + return s; + } int32_t engine_omp_thread_num; s = GetEngineConfigOmpThreadNum(engine_omp_thread_num); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* resource config */ std::string resource_mode; s = GetResourceConfigMode(resource_mode); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::vector resource_pool; s = GetResourceConfigPool(resource_pool); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } return Status::OK(); } @@ -182,81 +226,125 @@ Config::ResetDefaultConfig() { /* server config */ s = SetServerConfigAddress(CONFIG_SERVER_ADDRESS_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetServerConfigPort(CONFIG_SERVER_PORT_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetServerConfigDeployMode(CONFIG_SERVER_DEPLOY_MODE_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetServerConfigTimeZone(CONFIG_SERVER_TIME_ZONE_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* db config */ s = SetDBConfigPrimaryPath(CONFIG_DB_PRIMARY_PATH_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigSecondaryPath(CONFIG_DB_SECONDARY_PATH_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigBackendUrl(CONFIG_DB_BACKEND_URL_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigArchiveDiskThreshold(CONFIG_DB_ARCHIVE_DISK_THRESHOLD_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigArchiveDaysThreshold(CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigInsertBufferSize(CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetDBConfigBuildIndexGPU(CONFIG_DB_BUILD_INDEX_GPU_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* metric config */ s = SetMetricConfigEnableMonitor(CONFIG_METRIC_ENABLE_MONITOR_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetMetricConfigCollector(CONFIG_METRIC_COLLECTOR_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } s = SetMetricConfigPrometheusPort(CONFIG_METRIC_PROMETHEUS_PORT_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* cache config */ - s = SetCacheConfigCpuMemCapacity(CONFIG_CACHE_CPU_MEM_CAPACITY_DEFAULT); - if (!s.ok()) return s; + s = SetCacheConfigCpuCacheCapacity(CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT); + if (!s.ok()) { + return s; + } - s = SetCacheConfigCpuMemThreshold(CONFIG_CACHE_CPU_MEM_THRESHOLD_DEFAULT); - if (!s.ok()) return s; + s = SetCacheConfigCpuCacheThreshold(CONFIG_CACHE_CPU_CACHE_THRESHOLD_DEFAULT); + if (!s.ok()) { + return s; + } - s = SetCacheConfigGpuMemCapacity(CONFIG_CACHE_GPU_MEM_CAPACITY_DEFAULT); - if (!s.ok()) return s; + s = SetCacheConfigGpuCacheCapacity(CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT); + if (!s.ok()) { + return s; + } - s = SetCacheConfigGpuMemThreshold(CONFIG_CACHE_GPU_MEM_THRESHOLD_DEFAULT); - if (!s.ok()) return s; + s = SetCacheConfigGpuCacheThreshold(CONFIG_CACHE_GPU_CACHE_THRESHOLD_DEFAULT); + if (!s.ok()) { + return s; + } s = SetCacheConfigCacheInsertData(CONFIG_CACHE_CACHE_INSERT_DATA_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* engine config */ - s = SetEngineConfigBlasThreshold(CONFIG_ENGINE_BLAS_THRESHOLD_DEFAULT); - if (!s.ok()) return s; + s = SetEngineConfigUseBlasThreshold(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT); + if (!s.ok()) { + return s; + } s = SetEngineConfigOmpThreadNum(CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } /* resource config */ s = SetResourceConfigMode(CONFIG_RESOURCE_MODE_DEFAULT); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } return Status::OK(); } void -Config::PrintConfigSection(const std::string &config_node_name) { +Config::PrintConfigSection(const std::string& config_node_name) { std::cout << std::endl; std::cout << config_node_name << ":" << std::endl; if (config_map_.find(config_node_name) != config_map_.end()) { @@ -278,7 +366,7 @@ Config::PrintAll() { //////////////////////////////////////////////////////////////////////////////// Status -Config::CheckServerConfigAddress(const std::string &value) { +Config::CheckServerConfigAddress(const std::string& value) { if (!ValidationUtil::ValidateIpAddress(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid server config address: " + value); } @@ -286,7 +374,7 @@ Config::CheckServerConfigAddress(const std::string &value) { } Status -Config::CheckServerConfigPort(const std::string &value) { +Config::CheckServerConfigPort(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid server config port: " + value); } else { @@ -299,10 +387,8 @@ Config::CheckServerConfigPort(const std::string &value) { } Status -Config::CheckServerConfigDeployMode(const std::string &value) { - if (value != "single" && - value != "cluster_readonly" && - value != "cluster_writable") { +Config::CheckServerConfigDeployMode(const std::string& value) { + if (value != "single" && value != "cluster_readonly" && value != "cluster_writable") { return Status(SERVER_INVALID_ARGUMENT, "Invalid server config mode [single, cluster_readonly, cluster_writable]: " + value); } @@ -310,7 +396,7 @@ Config::CheckServerConfigDeployMode(const std::string &value) { } Status -Config::CheckServerConfigTimeZone(const std::string &value) { +Config::CheckServerConfigTimeZone(const std::string& value) { if (value.length() <= 3) { return Status(SERVER_INVALID_ARGUMENT, "Invalid server config time_zone: " + value); } else { @@ -328,7 +414,7 @@ Config::CheckServerConfigTimeZone(const std::string &value) { } Status -Config::CheckDBConfigPrimaryPath(const std::string &value) { +Config::CheckDBConfigPrimaryPath(const std::string& value) { if (value.empty()) { return Status(SERVER_INVALID_ARGUMENT, "DB config primary_path empty"); } @@ -336,12 +422,12 @@ Config::CheckDBConfigPrimaryPath(const std::string &value) { } Status -Config::CheckDBConfigSecondaryPath(const std::string &value) { +Config::CheckDBConfigSecondaryPath(const std::string& value) { return Status::OK(); } Status -Config::CheckDBConfigBackendUrl(const std::string &value) { +Config::CheckDBConfigBackendUrl(const std::string& value) { if (!ValidationUtil::ValidateDbURI(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config backend_url: " + value); } @@ -349,7 +435,7 @@ Config::CheckDBConfigBackendUrl(const std::string &value) { } Status -Config::CheckDBConfigArchiveDiskThreshold(const std::string &value) { +Config::CheckDBConfigArchiveDiskThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config archive_disk_threshold: " + value); } @@ -357,7 +443,7 @@ Config::CheckDBConfigArchiveDiskThreshold(const std::string &value) { } Status -Config::CheckDBConfigArchiveDaysThreshold(const std::string &value) { +Config::CheckDBConfigArchiveDaysThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config archive_days_threshold: " + value); } @@ -365,7 +451,7 @@ Config::CheckDBConfigArchiveDaysThreshold(const std::string &value) { } Status -Config::CheckDBConfigInsertBufferSize(const std::string &value) { +Config::CheckDBConfigInsertBufferSize(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config insert_buffer_size: " + value); } else { @@ -380,7 +466,7 @@ Config::CheckDBConfigInsertBufferSize(const std::string &value) { } Status -Config::CheckDBConfigBuildIndexGPU(const std::string &value) { +Config::CheckDBConfigBuildIndexGPU(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config build_index_gpu: " + value); } else { @@ -393,7 +479,7 @@ Config::CheckDBConfigBuildIndexGPU(const std::string &value) { } Status -Config::CheckMetricConfigEnableMonitor(const std::string &value) { +Config::CheckMetricConfigEnableMonitor(const std::string& value) { if (!ValidationUtil::ValidateStringIsBool(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid metric config auto_bootup: " + value); } @@ -401,7 +487,7 @@ Config::CheckMetricConfigEnableMonitor(const std::string &value) { } Status -Config::CheckMetricConfigCollector(const std::string &value) { +Config::CheckMetricConfigCollector(const std::string& value) { if (value != "prometheus") { return Status(SERVER_INVALID_ARGUMENT, "Invalid metric config collector: " + value); } @@ -409,7 +495,7 @@ Config::CheckMetricConfigCollector(const std::string &value) { } Status -Config::CheckMetricConfigPrometheusPort(const std::string &value) { +Config::CheckMetricConfigPrometheusPort(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid metric config prometheus_port: " + value); } @@ -417,81 +503,87 @@ Config::CheckMetricConfigPrometheusPort(const std::string &value) { } Status -Config::CheckCacheConfigCpuMemCapacity(const std::string &value) { +Config::CheckCacheConfigCpuCacheCapacity(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_mem_capacity: " + value); + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_cache_capacity: " + value); } else { uint64_t cpu_cache_capacity = std::stoi(value) * GB; uint64_t total_mem = 0, free_mem = 0; CommonUtil::GetSystemMemInfo(total_mem, free_mem); if (cpu_cache_capacity >= total_mem) { - return Status(SERVER_INVALID_ARGUMENT, "Cache config cpu_mem_capacity exceed system memory: " + value); - } else if (cpu_cache_capacity > (double) total_mem * 0.9) { - std::cerr << "Warning: cpu_mem_capacity value is too big" << std::endl; + return Status(SERVER_INVALID_ARGUMENT, "Cache config cpu_cache_capacity exceed system memory: " + value); + } else if (cpu_cache_capacity > static_cast(total_mem * 0.9)) { + std::cerr << "Warning: cpu_cache_capacity value is too big" << std::endl; } int32_t buffer_value; Status s = GetDBConfigInsertBufferSize(buffer_value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + int64_t insert_buffer_size = buffer_value * GB; if (insert_buffer_size + cpu_cache_capacity >= total_mem) { - return Status(SERVER_INVALID_ARGUMENT, "Sum of cpu_mem_capacity and buffer_size exceed system memory"); + return Status(SERVER_INVALID_ARGUMENT, "Sum of cpu_cache_capacity and buffer_size exceed system memory"); } } return Status::OK(); } Status -Config::CheckCacheConfigCpuMemThreshold(const std::string &value) { +Config::CheckCacheConfigCpuCacheThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsFloat(value).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_mem_threshold: " + value); + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_cache_threshold: " + value); } else { - float cpu_mem_threshold = std::stof(value); - if (cpu_mem_threshold <= 0.0 || cpu_mem_threshold >= 1.0) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_mem_threshold: " + value); + float cpu_cache_threshold = std::stof(value); + if (cpu_cache_threshold <= 0.0 || cpu_cache_threshold >= 1.0) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cpu_cache_threshold: " + value); } } return Status::OK(); } Status -Config::CheckCacheConfigGpuMemCapacity(const std::string &value) { +Config::CheckCacheConfigGpuCacheCapacity(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { - std::cerr << "ERROR: gpu_cache_capacity " << value << " is not a number" << std::endl; + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config gpu_cache_capacity: " + value); } else { uint64_t gpu_cache_capacity = std::stoi(value) * GB; int gpu_index; Status s = GetDBConfigBuildIndexGPU(gpu_index); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + size_t gpu_memory; if (!ValidationUtil::GetGpuMemory(gpu_index, gpu_memory).ok()) { return Status(SERVER_UNEXPECTED_ERROR, "Fail to get GPU memory for GPU device: " + std::to_string(gpu_index)); } else if (gpu_cache_capacity >= gpu_memory) { return Status(SERVER_INVALID_ARGUMENT, - "Cache config gpu_mem_capacity exceed GPU memory: " + std::to_string(gpu_memory)); - } else if (gpu_cache_capacity > (double) gpu_memory * 0.9) { - std::cerr << "Warning: gpu_mem_capacity value is too big" << std::endl; + "Cache config gpu_cache_capacity exceed GPU memory: " + std::to_string(gpu_memory)); + } else if (gpu_cache_capacity > (double)gpu_memory * 0.9) { + std::cerr << "Warning: gpu_cache_capacity value is too big" << std::endl; } } return Status::OK(); } Status -Config::CheckCacheConfigGpuMemThreshold(const std::string &value) { +Config::CheckCacheConfigGpuCacheThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsFloat(value).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config gpu_mem_threshold: " + value); + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config gpu_cache_threshold: " + value); } else { - float gpu_mem_threshold = std::stof(value); - if (gpu_mem_threshold <= 0.0 || gpu_mem_threshold >= 1.0) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config gpu_mem_threshold: " + value); + float gpu_cache_threshold = std::stof(value); + if (gpu_cache_threshold <= 0.0 || gpu_cache_threshold >= 1.0) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config gpu_cache_threshold: " + value); } } return Status::OK(); } Status -Config::CheckCacheConfigCacheInsertData(const std::string &value) { +Config::CheckCacheConfigCacheInsertData(const std::string& value) { if (!ValidationUtil::ValidateStringIsBool(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid cache config cache_insert_data: " + value); } @@ -499,29 +591,30 @@ Config::CheckCacheConfigCacheInsertData(const std::string &value) { } Status -Config::CheckEngineConfigBlasThreshold(const std::string &value) { +Config::CheckEngineConfigUseBlasThreshold(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid engine config blas threshold: " + value); + return Status(SERVER_INVALID_ARGUMENT, "Invalid engine config use_blas_threshold: " + value); } return Status::OK(); } Status -Config::CheckEngineConfigOmpThreadNum(const std::string &value) { +Config::CheckEngineConfigOmpThreadNum(const std::string& value) { if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid engine config omp_thread_num: " + value); - } else { - int32_t omp_thread = std::stoi(value); - uint32_t sys_thread_cnt = 8; - if (omp_thread > CommonUtil::GetSystemAvailableThreads(sys_thread_cnt)) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid engine config omp_thread_num: " + value); - } + } + + int32_t omp_thread = std::stoi(value); + uint32_t sys_thread_cnt = 8; + CommonUtil::GetSystemAvailableThreads(sys_thread_cnt); + if (omp_thread > static_cast(sys_thread_cnt)) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid engine config omp_thread_num: " + value); } return Status::OK(); } Status -Config::CheckResourceConfigMode(const std::string &value) { +Config::CheckResourceConfigMode(const std::string& value) { if (value != "simple") { return Status(SERVER_INVALID_ARGUMENT, "Invalid resource config mode: " + value); } @@ -529,7 +622,7 @@ Config::CheckResourceConfigMode(const std::string &value) { } Status -Config::CheckResourceConfigPool(const std::vector &value) { +Config::CheckResourceConfigPool(const std::vector& value) { if (value.empty()) { return Status(SERVER_INVALID_ARGUMENT, "Invalid resource config pool"); } @@ -537,462 +630,263 @@ Config::CheckResourceConfigPool(const std::vector &value) { } //////////////////////////////////////////////////////////////////////////////// -ConfigNode & -Config::GetConfigNode(const std::string &name) { - ConfigMgr *mgr = ConfigMgr::GetInstance(); - ConfigNode &root_node = mgr->GetRootNode(); +ConfigNode& +Config::GetConfigNode(const std::string& name) { + ConfigMgr* mgr = ConfigMgr::GetInstance(); + ConfigNode& root_node = mgr->GetRootNode(); return root_node.GetChild(name); } Status -Config::GetConfigValueInMem(const std::string &parent_key, - const std::string &child_key, - std::string &value) { +Config::GetConfigValueInMem(const std::string& parent_key, const std::string& child_key, std::string& value) { std::lock_guard lock(mutex_); if (config_map_.find(parent_key) != config_map_.end() && config_map_[parent_key].find(child_key) != config_map_[parent_key].end()) { value = config_map_[parent_key][child_key]; return Status::OK(); - } else { - return Status(SERVER_UNEXPECTED_ERROR, "key not exist"); } + return Status(SERVER_UNEXPECTED_ERROR, "key not exist"); } void -Config::SetConfigValueInMem(const std::string &parent_key, - const std::string &child_key, - const std::string &value) { +Config::SetConfigValueInMem(const std::string& parent_key, const std::string& child_key, const std::string& value) { std::lock_guard lock(mutex_); config_map_[parent_key][child_key] = value; } //////////////////////////////////////////////////////////////////////////////// -/* server config */ -std::string -Config::GetServerConfigStrAddress() { - std::string value; - if (!GetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_ADDRESS, value).ok()) { - value = GetConfigNode(CONFIG_SERVER).GetValue(CONFIG_SERVER_ADDRESS, - CONFIG_SERVER_ADDRESS_DEFAULT); - SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_ADDRESS, value); - } - return value; -} - std::string -Config::GetServerConfigStrPort() { +Config::GetConfigStr(const std::string& parent_key, const std::string& child_key, const std::string& default_value) { std::string value; - if (!GetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_PORT, value).ok()) { - value = GetConfigNode(CONFIG_SERVER).GetValue(CONFIG_SERVER_PORT, - CONFIG_SERVER_PORT_DEFAULT); - SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_PORT, value); + if (!GetConfigValueInMem(parent_key, child_key, value).ok()) { + value = GetConfigNode(parent_key).GetValue(child_key, default_value); + SetConfigValueInMem(parent_key, child_key, value); } return value; } -std::string -Config::GetServerConfigStrDeployMode() { - std::string value; - if (!GetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_DEPLOY_MODE, value).ok()) { - value = GetConfigNode(CONFIG_SERVER).GetValue(CONFIG_SERVER_DEPLOY_MODE, - CONFIG_SERVER_DEPLOY_MODE_DEFAULT); - SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_DEPLOY_MODE, value); - } - return value; -} - -std::string -Config::GetServerConfigStrTimeZone() { - std::string value; - if (!GetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_TIME_ZONE, value).ok()) { - value = GetConfigNode(CONFIG_SERVER).GetValue(CONFIG_SERVER_TIME_ZONE, - CONFIG_SERVER_TIME_ZONE_DEFAULT); - SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_TIME_ZONE, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// -/* db config */ -std::string -Config::GetDBConfigStrPrimaryPath() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRIMARY_PATH, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_PRIMARY_PATH, - CONFIG_DB_PRIMARY_PATH_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRIMARY_PATH, value); - } - return value; -} - -std::string -Config::GetDBConfigStrSecondaryPath() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_SECONDARY_PATH, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_SECONDARY_PATH, - CONFIG_DB_SECONDARY_PATH_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_SECONDARY_PATH, value); - } - return value; -} - -std::string -Config::GetDBConfigStrBackendUrl() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_BACKEND_URL, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_BACKEND_URL, - CONFIG_DB_BACKEND_URL_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_BACKEND_URL, value); - } - return value; -} - -std::string -Config::GetDBConfigStrArchiveDiskThreshold() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DISK_THRESHOLD, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_ARCHIVE_DISK_THRESHOLD, - CONFIG_DB_ARCHIVE_DISK_THRESHOLD_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DISK_THRESHOLD, value); - } - return value; -} - -std::string -Config::GetDBConfigStrArchiveDaysThreshold() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DAYS_THRESHOLD, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_ARCHIVE_DAYS_THRESHOLD, - CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DAYS_THRESHOLD, value); - } - return value; -} - -std::string -Config::GetDBConfigStrInsertBufferSize() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_INSERT_BUFFER_SIZE, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_INSERT_BUFFER_SIZE, - CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_INSERT_BUFFER_SIZE, value); - } - return value; -} - -std::string -Config::GetDBConfigStrBuildIndexGPU() { - std::string value; - if (!GetConfigValueInMem(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, value).ok()) { - value = GetConfigNode(CONFIG_DB).GetValue(CONFIG_DB_BUILD_INDEX_GPU, - CONFIG_DB_BUILD_INDEX_GPU_DEFAULT); - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// -/* metric config */ -std::string -Config::GetMetricConfigStrEnableMonitor() { - std::string value; - if (!GetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_ENABLE_MONITOR, value).ok()) { - value = GetConfigNode(CONFIG_METRIC).GetValue(CONFIG_METRIC_ENABLE_MONITOR, - CONFIG_METRIC_ENABLE_MONITOR_DEFAULT); - SetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_ENABLE_MONITOR, value); - } - return value; -} - -std::string -Config::GetMetricConfigStrCollector() { - std::string value; - if (!GetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_COLLECTOR, value).ok()) { - value = GetConfigNode(CONFIG_METRIC).GetValue(CONFIG_METRIC_COLLECTOR, - CONFIG_METRIC_COLLECTOR_DEFAULT); - SetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_COLLECTOR, value); - } - return value; -} - -std::string -Config::GetMetricConfigStrPrometheusPort() { - std::string value; - if (!GetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_PROMETHEUS_PORT, value).ok()) { - value = GetConfigNode(CONFIG_METRIC).GetValue(CONFIG_METRIC_PROMETHEUS_PORT, - CONFIG_METRIC_PROMETHEUS_PORT_DEFAULT); - SetConfigValueInMem(CONFIG_METRIC, CONFIG_METRIC_PROMETHEUS_PORT, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// -/* cache config */ -std::string -Config::GetCacheConfigStrCpuMemCapacity() { - std::string value; - if (!GetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CPU_MEM_CAPACITY, value).ok()) { - value = GetConfigNode(CONFIG_CACHE).GetValue(CONFIG_CACHE_CPU_MEM_CAPACITY, - CONFIG_CACHE_CPU_MEM_CAPACITY_DEFAULT); - SetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CPU_MEM_CAPACITY, value); - } - return value; -} - -std::string -Config::GetCacheConfigStrCpuMemThreshold() { - std::string value; - if (!GetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CPU_MEM_THRESHOLD, value).ok()) { - value = GetConfigNode(CONFIG_CACHE).GetValue(CONFIG_CACHE_CPU_MEM_THRESHOLD, - CONFIG_CACHE_CPU_MEM_THRESHOLD_DEFAULT); - SetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CPU_MEM_THRESHOLD, value); - } - return value; -} - -std::string -Config::GetCacheConfigStrGpuMemCapacity() { - std::string value; - if (!GetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_GPU_MEM_CAPACITY, value).ok()) { - value = GetConfigNode(CONFIG_CACHE).GetValue(CONFIG_CACHE_GPU_MEM_CAPACITY, - CONFIG_CACHE_GPU_MEM_CAPACITY_DEFAULT); - SetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_GPU_MEM_CAPACITY, value); - } - return value; -} - -std::string -Config::GetCacheConfigStrGpuMemThreshold() { - std::string value; - if (!GetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_GPU_MEM_THRESHOLD, value).ok()) { - value = GetConfigNode(CONFIG_CACHE).GetValue(CONFIG_CACHE_GPU_MEM_THRESHOLD, - CONFIG_CACHE_GPU_MEM_THRESHOLD_DEFAULT); - SetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_GPU_MEM_THRESHOLD, value); - } - return value; -} - -std::string -Config::GetCacheConfigStrCacheInsertData() { - std::string value; - if (!GetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CACHE_INSERT_DATA, value).ok()) { - value = GetConfigNode(CONFIG_CACHE).GetValue(CONFIG_CACHE_CACHE_INSERT_DATA, - CONFIG_CACHE_CACHE_INSERT_DATA_DEFAULT); - SetConfigValueInMem(CONFIG_CACHE, CONFIG_CACHE_CACHE_INSERT_DATA, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// -/* engine config */ -std::string -Config::GetEngineConfigStrBlasThreshold() { - std::string value; - if (!GetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_BLAS_THRESHOLD, value).ok()) { - value = GetConfigNode(CONFIG_ENGINE).GetValue(CONFIG_ENGINE_BLAS_THRESHOLD, - CONFIG_ENGINE_BLAS_THRESHOLD_DEFAULT); - SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_BLAS_THRESHOLD, value); - } - return value; -} - -std::string -Config::GetEngineConfigStrOmpThreadNum() { - std::string value; - if (!GetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_OMP_THREAD_NUM, value).ok()) { - value = GetConfigNode(CONFIG_ENGINE).GetValue(CONFIG_ENGINE_OMP_THREAD_NUM, - CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT); - SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_OMP_THREAD_NUM, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// -/* resource config */ -std::string -Config::GetResourceConfigStrMode() { - std::string value; - if (!GetConfigValueInMem(CONFIG_RESOURCE, CONFIG_RESOURCE_MODE, value).ok()) { - value = GetConfigNode(CONFIG_RESOURCE).GetValue(CONFIG_RESOURCE_MODE, - CONFIG_RESOURCE_MODE_DEFAULT); - SetConfigValueInMem(CONFIG_RESOURCE, CONFIG_RESOURCE_MODE, value); - } - return value; -} - -//////////////////////////////////////////////////////////////////////////////// Status -Config::GetServerConfigAddress(std::string &value) { - value = GetServerConfigStrAddress(); +Config::GetServerConfigAddress(std::string& value) { + value = GetConfigStr(CONFIG_SERVER, CONFIG_SERVER_ADDRESS, CONFIG_SERVER_ADDRESS_DEFAULT); return CheckServerConfigAddress(value); } Status -Config::GetServerConfigPort(std::string &value) { - value = GetServerConfigStrPort(); +Config::GetServerConfigPort(std::string& value) { + value = GetConfigStr(CONFIG_SERVER, CONFIG_SERVER_PORT, CONFIG_SERVER_PORT_DEFAULT); return CheckServerConfigPort(value); } Status -Config::GetServerConfigDeployMode(std::string &value) { - value = GetServerConfigStrDeployMode(); +Config::GetServerConfigDeployMode(std::string& value) { + value = GetConfigStr(CONFIG_SERVER, CONFIG_SERVER_DEPLOY_MODE, CONFIG_SERVER_DEPLOY_MODE_DEFAULT); return CheckServerConfigDeployMode(value); } Status -Config::GetServerConfigTimeZone(std::string &value) { - value = GetServerConfigStrTimeZone(); +Config::GetServerConfigTimeZone(std::string& value) { + value = GetConfigStr(CONFIG_SERVER, CONFIG_SERVER_TIME_ZONE, CONFIG_SERVER_TIME_ZONE_DEFAULT); return CheckServerConfigTimeZone(value); } Status -Config::GetDBConfigPrimaryPath(std::string &value) { - value = GetDBConfigStrPrimaryPath(); +Config::GetDBConfigPrimaryPath(std::string& value) { + value = GetConfigStr(CONFIG_DB, CONFIG_DB_PRIMARY_PATH, CONFIG_DB_PRIMARY_PATH_DEFAULT); return CheckDBConfigPrimaryPath(value); } Status -Config::GetDBConfigSecondaryPath(std::string &value) { - value = GetDBConfigStrSecondaryPath(); +Config::GetDBConfigSecondaryPath(std::string& value) { + value = GetConfigStr(CONFIG_DB, CONFIG_DB_SECONDARY_PATH, CONFIG_DB_SECONDARY_PATH_DEFAULT); return Status::OK(); } Status -Config::GetDBConfigBackendUrl(std::string &value) { - value = GetDBConfigStrBackendUrl(); +Config::GetDBConfigBackendUrl(std::string& value) { + value = GetConfigStr(CONFIG_DB, CONFIG_DB_BACKEND_URL, CONFIG_DB_BACKEND_URL_DEFAULT); return CheckDBConfigBackendUrl(value); } Status -Config::GetDBConfigArchiveDiskThreshold(int32_t &value) { - std::string str = GetDBConfigStrArchiveDiskThreshold(); +Config::GetDBConfigArchiveDiskThreshold(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_DB, CONFIG_DB_ARCHIVE_DISK_THRESHOLD, CONFIG_DB_ARCHIVE_DISK_THRESHOLD_DEFAULT); Status s = CheckDBConfigArchiveDiskThreshold(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetDBConfigArchiveDaysThreshold(int32_t &value) { - std::string str = GetDBConfigStrArchiveDaysThreshold(); +Config::GetDBConfigArchiveDaysThreshold(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_DB, CONFIG_DB_ARCHIVE_DAYS_THRESHOLD, CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT); Status s = CheckDBConfigArchiveDaysThreshold(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetDBConfigInsertBufferSize(int32_t &value) { - std::string str = GetDBConfigStrInsertBufferSize(); +Config::GetDBConfigInsertBufferSize(int32_t& value) { + std::string str = GetConfigStr(CONFIG_DB, CONFIG_DB_INSERT_BUFFER_SIZE, CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT); Status s = CheckDBConfigInsertBufferSize(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetDBConfigBuildIndexGPU(int32_t &value) { - std::string str = GetDBConfigStrBuildIndexGPU(); +Config::GetDBConfigBuildIndexGPU(int32_t& value) { + std::string str = GetConfigStr(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, CONFIG_DB_BUILD_INDEX_GPU_DEFAULT); Status s = CheckDBConfigBuildIndexGPU(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetMetricConfigEnableMonitor(bool &value) { - std::string str = GetMetricConfigStrEnableMonitor(); +Config::GetDBConfigPreloadTable(std::string& value) { + value = GetConfigStr(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE); + return Status::OK(); +} + +Status +Config::GetMetricConfigEnableMonitor(bool& value) { + std::string str = GetConfigStr(CONFIG_METRIC, CONFIG_METRIC_ENABLE_MONITOR, CONFIG_METRIC_ENABLE_MONITOR_DEFAULT); Status s = CheckMetricConfigEnableMonitor(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + std::transform(str.begin(), str.end(), str.begin(), ::tolower); value = (str == "true" || str == "on" || str == "yes" || str == "1"); return Status::OK(); } Status -Config::GetMetricConfigCollector(std::string &value) { - value = GetMetricConfigStrCollector(); +Config::GetMetricConfigCollector(std::string& value) { + value = GetConfigStr(CONFIG_METRIC, CONFIG_METRIC_COLLECTOR, CONFIG_METRIC_COLLECTOR_DEFAULT); return Status::OK(); } Status -Config::GetMetricConfigPrometheusPort(std::string &value) { - value = GetMetricConfigStrPrometheusPort(); +Config::GetMetricConfigPrometheusPort(std::string& value) { + value = GetConfigStr(CONFIG_METRIC, CONFIG_METRIC_PROMETHEUS_PORT, CONFIG_METRIC_PROMETHEUS_PORT_DEFAULT); return CheckMetricConfigPrometheusPort(value); } Status -Config::GetCacheConfigCpuMemCapacity(int32_t &value) { - std::string str = GetCacheConfigStrCpuMemCapacity(); - Status s = CheckCacheConfigCpuMemCapacity(str); - if (!s.ok()) return s; +Config::GetCacheConfigCpuCacheCapacity(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT); + Status s = CheckCacheConfigCpuCacheCapacity(str); + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetCacheConfigCpuMemThreshold(float &value) { - std::string str = GetCacheConfigStrCpuMemThreshold(); - Status s = CheckCacheConfigCpuMemThreshold(str); - if (!s.ok()) return s; +Config::GetCacheConfigCpuCacheThreshold(float& value) { + std::string str = + GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_THRESHOLD, CONFIG_CACHE_CPU_CACHE_THRESHOLD_DEFAULT); + Status s = CheckCacheConfigCpuCacheThreshold(str); + if (!s.ok()) { + return s; + } + value = std::stof(str); return Status::OK(); } Status -Config::GetCacheConfigGpuMemCapacity(int32_t &value) { - std::string str = GetCacheConfigStrGpuMemCapacity(); - Status s = CheckCacheConfigGpuMemCapacity(str); - if (!s.ok()) return s; +Config::GetCacheConfigGpuCacheCapacity(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_GPU_CACHE_CAPACITY, CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT); + Status s = CheckCacheConfigGpuCacheCapacity(str); + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetCacheConfigGpuMemThreshold(float &value) { - std::string str = GetCacheConfigStrGpuMemThreshold(); - Status s = CheckCacheConfigGpuMemThreshold(str); - if (!s.ok()) return s; +Config::GetCacheConfigGpuCacheThreshold(float& value) { + std::string str = + GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_GPU_CACHE_THRESHOLD, CONFIG_CACHE_GPU_CACHE_THRESHOLD_DEFAULT); + Status s = CheckCacheConfigGpuCacheThreshold(str); + if (!s.ok()) { + return s; + } + value = std::stof(str); return Status::OK(); } Status -Config::GetCacheConfigCacheInsertData(bool &value) { - std::string str = GetCacheConfigStrCacheInsertData(); +Config::GetCacheConfigCacheInsertData(bool& value) { + std::string str = + GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CACHE_INSERT_DATA, CONFIG_CACHE_CACHE_INSERT_DATA_DEFAULT); Status s = CheckCacheConfigCacheInsertData(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + std::transform(str.begin(), str.end(), str.begin(), ::tolower); value = (str == "true" || str == "on" || str == "yes" || str == "1"); return Status::OK(); } Status -Config::GetEngineConfigBlasThreshold(int32_t &value) { - std::string str = GetEngineConfigStrBlasThreshold(); - Status s = CheckEngineConfigBlasThreshold(str); - if (!s.ok()) return s; +Config::GetEngineConfigUseBlasThreshold(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_ENGINE, CONFIG_ENGINE_USE_BLAS_THRESHOLD, CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT); + Status s = CheckEngineConfigUseBlasThreshold(str); + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetEngineConfigOmpThreadNum(int32_t &value) { - std::string str = GetEngineConfigStrOmpThreadNum(); +Config::GetEngineConfigOmpThreadNum(int32_t& value) { + std::string str = GetConfigStr(CONFIG_ENGINE, CONFIG_ENGINE_OMP_THREAD_NUM, CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT); Status s = CheckEngineConfigOmpThreadNum(str); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + value = std::stoi(str); return Status::OK(); } Status -Config::GetResourceConfigMode(std::string &value) { - value = GetResourceConfigStrMode(); +Config::GetResourceConfigMode(std::string& value) { + value = GetConfigStr(CONFIG_RESOURCE, CONFIG_RESOURCE_MODE, CONFIG_RESOURCE_MODE_DEFAULT); return CheckResourceConfigMode(value); } Status -Config::GetResourceConfigPool(std::vector &value) { +Config::GetResourceConfigPool(std::vector& value) { ConfigNode resource_config = GetConfigNode(CONFIG_RESOURCE); value = resource_config.GetSequence(CONFIG_RESOURCE_POOL); return CheckResourceConfigPool(value); @@ -1001,187 +895,251 @@ Config::GetResourceConfigPool(std::vector &value) { /////////////////////////////////////////////////////////////////////////////// /* server config */ Status -Config::SetServerConfigAddress(const std::string &value) { +Config::SetServerConfigAddress(const std::string& value) { Status s = CheckServerConfigAddress(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_ADDRESS, value); return Status::OK(); } Status -Config::SetServerConfigPort(const std::string &value) { +Config::SetServerConfigPort(const std::string& value) { Status s = CheckServerConfigPort(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_PORT, value); return Status::OK(); } Status -Config::SetServerConfigDeployMode(const std::string &value) { +Config::SetServerConfigDeployMode(const std::string& value) { Status s = CheckServerConfigDeployMode(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_DEPLOY_MODE, value); return Status::OK(); } Status -Config::SetServerConfigTimeZone(const std::string &value) { +Config::SetServerConfigTimeZone(const std::string& value) { Status s = CheckServerConfigTimeZone(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_SERVER, CONFIG_SERVER_TIME_ZONE, value); return Status::OK(); } /* db config */ Status -Config::SetDBConfigPrimaryPath(const std::string &value) { +Config::SetDBConfigPrimaryPath(const std::string& value) { Status s = CheckDBConfigPrimaryPath(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_PRIMARY_PATH, value); return Status::OK(); } Status -Config::SetDBConfigSecondaryPath(const std::string &value) { +Config::SetDBConfigSecondaryPath(const std::string& value) { Status s = CheckDBConfigSecondaryPath(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_SECONDARY_PATH, value); return Status::OK(); } Status -Config::SetDBConfigBackendUrl(const std::string &value) { +Config::SetDBConfigBackendUrl(const std::string& value) { Status s = CheckDBConfigBackendUrl(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_BACKEND_URL, value); return Status::OK(); } Status -Config::SetDBConfigArchiveDiskThreshold(const std::string &value) { +Config::SetDBConfigArchiveDiskThreshold(const std::string& value) { Status s = CheckDBConfigArchiveDiskThreshold(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DISK_THRESHOLD, value); return Status::OK(); } Status -Config::SetDBConfigArchiveDaysThreshold(const std::string &value) { +Config::SetDBConfigArchiveDaysThreshold(const std::string& value) { Status s = CheckDBConfigArchiveDaysThreshold(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_ARCHIVE_DAYS_THRESHOLD, value); return Status::OK(); } Status -Config::SetDBConfigInsertBufferSize(const std::string &value) { +Config::SetDBConfigInsertBufferSize(const std::string& value) { Status s = CheckDBConfigInsertBufferSize(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_INSERT_BUFFER_SIZE, value); return Status::OK(); } Status -Config::SetDBConfigBuildIndexGPU(const std::string &value) { +Config::SetDBConfigBuildIndexGPU(const std::string& value) { Status s = CheckDBConfigBuildIndexGPU(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, value); return Status::OK(); } /* metric config */ Status -Config::SetMetricConfigEnableMonitor(const std::string &value) { +Config::SetMetricConfigEnableMonitor(const std::string& value) { Status s = CheckMetricConfigEnableMonitor(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_METRIC_ENABLE_MONITOR, value); return Status::OK(); } Status -Config::SetMetricConfigCollector(const std::string &value) { +Config::SetMetricConfigCollector(const std::string& value) { Status s = CheckMetricConfigCollector(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_METRIC_COLLECTOR, value); return Status::OK(); } Status -Config::SetMetricConfigPrometheusPort(const std::string &value) { +Config::SetMetricConfigPrometheusPort(const std::string& value) { Status s = CheckMetricConfigPrometheusPort(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_METRIC_PROMETHEUS_PORT, value); return Status::OK(); } /* cache config */ Status -Config::SetCacheConfigCpuMemCapacity(const std::string &value) { - Status s = CheckCacheConfigCpuMemCapacity(value); - if (!s.ok()) return s; - SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_CPU_MEM_CAPACITY, value); +Config::SetCacheConfigCpuCacheCapacity(const std::string& value) { + Status s = CheckCacheConfigCpuCacheCapacity(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_CPU_CACHE_CAPACITY, value); return Status::OK(); } Status -Config::SetCacheConfigCpuMemThreshold(const std::string &value) { - Status s = CheckCacheConfigCpuMemThreshold(value); - if (!s.ok()) return s; - SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_CPU_MEM_THRESHOLD, value); +Config::SetCacheConfigCpuCacheThreshold(const std::string& value) { + Status s = CheckCacheConfigCpuCacheThreshold(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_CPU_CACHE_THRESHOLD, value); return Status::OK(); } Status -Config::SetCacheConfigGpuMemCapacity(const std::string &value) { - Status s = CheckCacheConfigGpuMemCapacity(value); - if (!s.ok()) return s; - SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_GPU_MEM_CAPACITY, value); +Config::SetCacheConfigGpuCacheCapacity(const std::string& value) { + Status s = CheckCacheConfigGpuCacheCapacity(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_GPU_CACHE_CAPACITY, value); return Status::OK(); } Status -Config::SetCacheConfigGpuMemThreshold(const std::string &value) { - Status s = CheckCacheConfigGpuMemThreshold(value); - if (!s.ok()) return s; - SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_GPU_MEM_THRESHOLD, value); +Config::SetCacheConfigGpuCacheThreshold(const std::string& value) { + Status s = CheckCacheConfigGpuCacheThreshold(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_GPU_CACHE_THRESHOLD, value); return Status::OK(); } Status -Config::SetCacheConfigCacheInsertData(const std::string &value) { +Config::SetCacheConfigCacheInsertData(const std::string& value) { Status s = CheckCacheConfigCacheInsertData(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_CACHE_CACHE_INSERT_DATA, value); return Status::OK(); } /* engine config */ Status -Config::SetEngineConfigBlasThreshold(const std::string &value) { - Status s = CheckEngineConfigBlasThreshold(value); - if (!s.ok()) return s; - SetConfigValueInMem(CONFIG_DB, CONFIG_ENGINE_BLAS_THRESHOLD, value); +Config::SetEngineConfigUseBlasThreshold(const std::string& value) { + Status s = CheckEngineConfigUseBlasThreshold(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_ENGINE_USE_BLAS_THRESHOLD, value); return Status::OK(); } Status -Config::SetEngineConfigOmpThreadNum(const std::string &value) { +Config::SetEngineConfigOmpThreadNum(const std::string& value) { Status s = CheckEngineConfigOmpThreadNum(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_ENGINE_OMP_THREAD_NUM, value); return Status::OK(); } /* resource config */ Status -Config::SetResourceConfigMode(const std::string &value) { +Config::SetResourceConfigMode(const std::string& value) { Status s = CheckResourceConfigMode(value); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + SetConfigValueInMem(CONFIG_DB, CONFIG_RESOURCE_MODE, value); return Status::OK(); } -} // namespace server -} // namespace milvus -} // namespace zilliz - +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/Config.h b/cpp/src/server/Config.h index af23dd67b60c3f1d50abb25eaeffa3445c01dd31..fb00498b960dca86235a324ecdc7ecc6e791424c 100644 --- a/cpp/src/server/Config.h +++ b/cpp/src/server/Config.h @@ -17,252 +17,293 @@ #pragma once -#include -#include +#include #include +#include #include -#include +#include -#include "utils/Status.h" #include "config/ConfigNode.h" +#include "utils/Status.h" -namespace zilliz { namespace milvus { namespace server { /* server config */ -static const char *CONFIG_SERVER = "server_config"; -static const char *CONFIG_SERVER_ADDRESS = "address"; -static const char *CONFIG_SERVER_ADDRESS_DEFAULT = "127.0.0.1"; -static const char *CONFIG_SERVER_PORT = "port"; -static const char *CONFIG_SERVER_PORT_DEFAULT = "19530"; -static const char *CONFIG_SERVER_DEPLOY_MODE = "deploy_mode"; -static const char *CONFIG_SERVER_DEPLOY_MODE_DEFAULT = "single"; -static const char *CONFIG_SERVER_TIME_ZONE = "time_zone"; -static const char *CONFIG_SERVER_TIME_ZONE_DEFAULT = "UTC+8"; +static const char* CONFIG_SERVER = "server_config"; +static const char* CONFIG_SERVER_ADDRESS = "address"; +static const char* CONFIG_SERVER_ADDRESS_DEFAULT = "127.0.0.1"; +static const char* CONFIG_SERVER_PORT = "port"; +static const char* CONFIG_SERVER_PORT_DEFAULT = "19530"; +static const char* CONFIG_SERVER_DEPLOY_MODE = "deploy_mode"; +static const char* CONFIG_SERVER_DEPLOY_MODE_DEFAULT = "single"; +static const char* CONFIG_SERVER_TIME_ZONE = "time_zone"; +static const char* CONFIG_SERVER_TIME_ZONE_DEFAULT = "UTC+8"; /* db config */ -static const char *CONFIG_DB = "db_config"; -static const char *CONFIG_DB_PRIMARY_PATH = "primary_path"; -static const char *CONFIG_DB_PRIMARY_PATH_DEFAULT = "/tmp/milvus"; -static const char *CONFIG_DB_SECONDARY_PATH = "secondary_path"; -static const char *CONFIG_DB_SECONDARY_PATH_DEFAULT = ""; -static const char *CONFIG_DB_BACKEND_URL = "backend_url"; -static const char *CONFIG_DB_BACKEND_URL_DEFAULT = "sqlite://:@:/"; -static const char *CONFIG_DB_ARCHIVE_DISK_THRESHOLD = "archive_disk_threshold"; -static const char *CONFIG_DB_ARCHIVE_DISK_THRESHOLD_DEFAULT = "0"; -static const char *CONFIG_DB_ARCHIVE_DAYS_THRESHOLD = "archive_days_threshold"; -static const char *CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT = "0"; -static const char *CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size"; -static const char *CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT = "4"; -static const char *CONFIG_DB_BUILD_INDEX_GPU = "build_index_gpu"; -static const char *CONFIG_DB_BUILD_INDEX_GPU_DEFAULT = "0"; +static const char* CONFIG_DB = "db_config"; +static const char* CONFIG_DB_PRIMARY_PATH = "primary_path"; +static const char* CONFIG_DB_PRIMARY_PATH_DEFAULT = "/tmp/milvus"; +static const char* CONFIG_DB_SECONDARY_PATH = "secondary_path"; +static const char* CONFIG_DB_SECONDARY_PATH_DEFAULT = ""; +static const char* CONFIG_DB_BACKEND_URL = "backend_url"; +static const char* CONFIG_DB_BACKEND_URL_DEFAULT = "sqlite://:@:/"; +static const char* CONFIG_DB_ARCHIVE_DISK_THRESHOLD = "archive_disk_threshold"; +static const char* CONFIG_DB_ARCHIVE_DISK_THRESHOLD_DEFAULT = "0"; +static const char* CONFIG_DB_ARCHIVE_DAYS_THRESHOLD = "archive_days_threshold"; +static const char* CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT = "0"; +static const char* CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size"; +static const char* CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT = "4"; +static const char* CONFIG_DB_BUILD_INDEX_GPU = "build_index_gpu"; +static const char* CONFIG_DB_BUILD_INDEX_GPU_DEFAULT = "0"; +static const char* CONFIG_DB_PRELOAD_TABLE = "preload_table"; /* cache config */ -static const char *CONFIG_CACHE = "cache_config"; -static const char *CONFIG_CACHE_CPU_MEM_CAPACITY = "cpu_mem_capacity"; -static const char *CONFIG_CACHE_CPU_MEM_CAPACITY_DEFAULT = "16"; -static const char *CONFIG_CACHE_GPU_MEM_CAPACITY = "gpu_mem_capacity"; -static const char *CONFIG_CACHE_GPU_MEM_CAPACITY_DEFAULT = "0"; -static const char *CONFIG_CACHE_CPU_MEM_THRESHOLD = "cpu_mem_threshold"; -static const char *CONFIG_CACHE_CPU_MEM_THRESHOLD_DEFAULT = "0.85"; -static const char *CONFIG_CACHE_GPU_MEM_THRESHOLD = "gpu_mem_threshold"; -static const char *CONFIG_CACHE_GPU_MEM_THRESHOLD_DEFAULT = "0.85"; -static const char *CONFIG_CACHE_CACHE_INSERT_DATA = "cache_insert_data"; -static const char *CONFIG_CACHE_CACHE_INSERT_DATA_DEFAULT = "false"; +static const char* CONFIG_CACHE = "cache_config"; +static const char* CONFIG_CACHE_CPU_CACHE_CAPACITY = "cpu_cache_capacity"; +static const char* CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT = "16"; +static const char* CONFIG_CACHE_GPU_CACHE_CAPACITY = "gpu_cache_capacity"; +static const char* CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT = "0"; +static const char* CONFIG_CACHE_CPU_CACHE_THRESHOLD = "cpu_mem_threshold"; +static const char* CONFIG_CACHE_CPU_CACHE_THRESHOLD_DEFAULT = "0.85"; +static const char* CONFIG_CACHE_GPU_CACHE_THRESHOLD = "gpu_mem_threshold"; +static const char* CONFIG_CACHE_GPU_CACHE_THRESHOLD_DEFAULT = "0.85"; +static const char* CONFIG_CACHE_CACHE_INSERT_DATA = "cache_insert_data"; +static const char* CONFIG_CACHE_CACHE_INSERT_DATA_DEFAULT = "false"; /* metric config */ -static const char *CONFIG_METRIC = "metric_config"; -static const char *CONFIG_METRIC_ENABLE_MONITOR = "enable_monitor"; -static const char *CONFIG_METRIC_ENABLE_MONITOR_DEFAULT = "false"; -static const char *CONFIG_METRIC_COLLECTOR = "collector"; -static const char *CONFIG_METRIC_COLLECTOR_DEFAULT = "prometheus"; -static const char *CONFIG_METRIC_PROMETHEUS = "prometheus_config"; -static const char *CONFIG_METRIC_PROMETHEUS_PORT = "port"; -static const char *CONFIG_METRIC_PROMETHEUS_PORT_DEFAULT = "8080"; +static const char* CONFIG_METRIC = "metric_config"; +static const char* CONFIG_METRIC_ENABLE_MONITOR = "enable_monitor"; +static const char* CONFIG_METRIC_ENABLE_MONITOR_DEFAULT = "false"; +static const char* CONFIG_METRIC_COLLECTOR = "collector"; +static const char* CONFIG_METRIC_COLLECTOR_DEFAULT = "prometheus"; +static const char* CONFIG_METRIC_PROMETHEUS = "prometheus_config"; +static const char* CONFIG_METRIC_PROMETHEUS_PORT = "port"; +static const char* CONFIG_METRIC_PROMETHEUS_PORT_DEFAULT = "8080"; /* engine config */ -static const char *CONFIG_ENGINE = "engine_config"; -static const char *CONFIG_ENGINE_BLAS_THRESHOLD = "blas_threshold"; -static const char *CONFIG_ENGINE_BLAS_THRESHOLD_DEFAULT = "20"; -static const char *CONFIG_ENGINE_OMP_THREAD_NUM = "omp_thread_num"; -static const char *CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0"; +static const char* CONFIG_ENGINE = "engine_config"; +static const char* CONFIG_ENGINE_USE_BLAS_THRESHOLD = "use_blas_threshold"; +static const char* CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT = "20"; +static const char* CONFIG_ENGINE_OMP_THREAD_NUM = "omp_thread_num"; +static const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0"; /* resource config */ -static const char *CONFIG_RESOURCE = "resource_config"; -static const char *CONFIG_RESOURCE_MODE = "mode"; -static const char *CONFIG_RESOURCE_MODE_DEFAULT = "simple"; -static const char *CONFIG_RESOURCE_POOL = "resource_pool"; +static const char* CONFIG_RESOURCE = "resource_config"; +static const char* CONFIG_RESOURCE_MODE = "mode"; +static const char* CONFIG_RESOURCE_MODE_DEFAULT = "simple"; +static const char* CONFIG_RESOURCE_POOL = "resource_pool"; class Config { public: - static Config &GetInstance(); - Status LoadConfigFile(const std::string &filename); - Status ValidateConfig(); - Status ResetDefaultConfig(); - void PrintAll(); + static Config& + GetInstance(); + Status + LoadConfigFile(const std::string& filename); + Status + ValidateConfig(); + Status + ResetDefaultConfig(); + void + PrintAll(); private: - ConfigNode &GetConfigNode(const std::string &name); + ConfigNode& + GetConfigNode(const std::string& name); - Status GetConfigValueInMem(const std::string &parent_key, - const std::string &child_key, - std::string &value); + Status + GetConfigValueInMem(const std::string& parent_key, const std::string& child_key, std::string& value); - void SetConfigValueInMem(const std::string &parent_key, - const std::string &child_key, - const std::string &value); + void + SetConfigValueInMem(const std::string& parent_key, const std::string& child_key, const std::string& value); - void PrintConfigSection(const std::string &config_node_name); + void + PrintConfigSection(const std::string& config_node_name); /////////////////////////////////////////////////////////////////////////// /* server config */ - Status CheckServerConfigAddress(const std::string &value); - Status CheckServerConfigPort(const std::string &value); - Status CheckServerConfigDeployMode(const std::string &value); - Status CheckServerConfigTimeZone(const std::string &value); + Status + CheckServerConfigAddress(const std::string& value); + Status + CheckServerConfigPort(const std::string& value); + Status + CheckServerConfigDeployMode(const std::string& value); + Status + CheckServerConfigTimeZone(const std::string& value); /* db config */ - Status CheckDBConfigPrimaryPath(const std::string &value); - Status CheckDBConfigSecondaryPath(const std::string &value); - Status CheckDBConfigBackendUrl(const std::string &value); - Status CheckDBConfigArchiveDiskThreshold(const std::string &value); - Status CheckDBConfigArchiveDaysThreshold(const std::string &value); - Status CheckDBConfigInsertBufferSize(const std::string &value); - Status CheckDBConfigBuildIndexGPU(const std::string &value); + Status + CheckDBConfigPrimaryPath(const std::string& value); + Status + CheckDBConfigSecondaryPath(const std::string& value); + Status + CheckDBConfigBackendUrl(const std::string& value); + Status + CheckDBConfigArchiveDiskThreshold(const std::string& value); + Status + CheckDBConfigArchiveDaysThreshold(const std::string& value); + Status + CheckDBConfigInsertBufferSize(const std::string& value); + Status + CheckDBConfigBuildIndexGPU(const std::string& value); /* metric config */ - Status CheckMetricConfigEnableMonitor(const std::string &value); - Status CheckMetricConfigCollector(const std::string &value); - Status CheckMetricConfigPrometheusPort(const std::string &value); + Status + CheckMetricConfigEnableMonitor(const std::string& value); + Status + CheckMetricConfigCollector(const std::string& value); + Status + CheckMetricConfigPrometheusPort(const std::string& value); /* cache config */ - Status CheckCacheConfigCpuMemCapacity(const std::string &value); - Status CheckCacheConfigCpuMemThreshold(const std::string &value); - Status CheckCacheConfigGpuMemCapacity(const std::string &value); - Status CheckCacheConfigGpuMemThreshold(const std::string &value); - Status CheckCacheConfigCacheInsertData(const std::string &value); + Status + CheckCacheConfigCpuCacheCapacity(const std::string& value); + Status + CheckCacheConfigCpuCacheThreshold(const std::string& value); + Status + CheckCacheConfigGpuCacheCapacity(const std::string& value); + Status + CheckCacheConfigGpuCacheThreshold(const std::string& value); + Status + CheckCacheConfigCacheInsertData(const std::string& value); /* engine config */ - Status CheckEngineConfigBlasThreshold(const std::string &value); - Status CheckEngineConfigOmpThreadNum(const std::string &value); + Status + CheckEngineConfigUseBlasThreshold(const std::string& value); + Status + CheckEngineConfigOmpThreadNum(const std::string& value); /* resource config */ - Status CheckResourceConfigMode(const std::string &value); - Status CheckResourceConfigPool(const std::vector &value); + Status + CheckResourceConfigMode(const std::string& value); + Status + CheckResourceConfigPool(const std::vector& value); - /////////////////////////////////////////////////////////////////////////// - /* server config */ - std::string GetServerConfigStrAddress(); - std::string GetServerConfigStrPort(); - std::string GetServerConfigStrDeployMode(); - std::string GetServerConfigStrTimeZone(); - - /* db config */ - std::string GetDBConfigStrPrimaryPath(); - std::string GetDBConfigStrSecondaryPath(); - std::string GetDBConfigStrBackendUrl(); - std::string GetDBConfigStrArchiveDiskThreshold(); - std::string GetDBConfigStrArchiveDaysThreshold(); - std::string GetDBConfigStrInsertBufferSize(); - std::string GetDBConfigStrBuildIndexGPU(); - - /* metric config */ - std::string GetMetricConfigStrEnableMonitor(); - std::string GetMetricConfigStrCollector(); - std::string GetMetricConfigStrPrometheusPort(); - - /* cache config */ - std::string GetCacheConfigStrCpuMemCapacity(); - std::string GetCacheConfigStrCpuMemThreshold(); - std::string GetCacheConfigStrGpuMemCapacity(); - std::string GetCacheConfigStrGpuMemThreshold(); - std::string GetCacheConfigStrCacheInsertData(); - - /* engine config */ - std::string GetEngineConfigStrBlasThreshold(); - std::string GetEngineConfigStrOmpThreadNum(); - - /* resource config */ - std::string GetResourceConfigStrMode(); + std::string + GetConfigStr(const std::string& parent_key, const std::string& child_key, const std::string& default_value = ""); public: /* server config */ - Status GetServerConfigAddress(std::string &value); - Status GetServerConfigPort(std::string &value); - Status GetServerConfigDeployMode(std::string &value); - Status GetServerConfigTimeZone(std::string &value); + Status + GetServerConfigAddress(std::string& value); + Status + GetServerConfigPort(std::string& value); + Status + GetServerConfigDeployMode(std::string& value); + Status + GetServerConfigTimeZone(std::string& value); /* db config */ - Status GetDBConfigPrimaryPath(std::string &value); - Status GetDBConfigSecondaryPath(std::string &value); - Status GetDBConfigBackendUrl(std::string &value); - Status GetDBConfigArchiveDiskThreshold(int32_t &value); - Status GetDBConfigArchiveDaysThreshold(int32_t &value); - Status GetDBConfigInsertBufferSize(int32_t &value); - Status GetDBConfigBuildIndexGPU(int32_t &value); + Status + GetDBConfigPrimaryPath(std::string& value); + Status + GetDBConfigSecondaryPath(std::string& value); + Status + GetDBConfigBackendUrl(std::string& value); + Status + GetDBConfigArchiveDiskThreshold(int32_t& value); + Status + GetDBConfigArchiveDaysThreshold(int32_t& value); + Status + GetDBConfigInsertBufferSize(int32_t& value); + Status + GetDBConfigBuildIndexGPU(int32_t& value); + Status + GetDBConfigPreloadTable(std::string& value); /* metric config */ - Status GetMetricConfigEnableMonitor(bool &value); - Status GetMetricConfigCollector(std::string &value); - Status GetMetricConfigPrometheusPort(std::string &value); + Status + GetMetricConfigEnableMonitor(bool& value); + Status + GetMetricConfigCollector(std::string& value); + Status + GetMetricConfigPrometheusPort(std::string& value); /* cache config */ - Status GetCacheConfigCpuMemCapacity(int32_t &value); - Status GetCacheConfigCpuMemThreshold(float &value); - Status GetCacheConfigGpuMemCapacity(int32_t &value); - Status GetCacheConfigGpuMemThreshold(float &value); - Status GetCacheConfigCacheInsertData(bool &value); + Status + GetCacheConfigCpuCacheCapacity(int32_t& value); + Status + GetCacheConfigCpuCacheThreshold(float& value); + Status + GetCacheConfigGpuCacheCapacity(int32_t& value); + Status + GetCacheConfigGpuCacheThreshold(float& value); + Status + GetCacheConfigCacheInsertData(bool& value); /* engine config */ - Status GetEngineConfigBlasThreshold(int32_t &value); - Status GetEngineConfigOmpThreadNum(int32_t &value); + Status + GetEngineConfigUseBlasThreshold(int32_t& value); + Status + GetEngineConfigOmpThreadNum(int32_t& value); /* resource config */ - Status GetResourceConfigMode(std::string &value); - Status GetResourceConfigPool(std::vector &value); + Status + GetResourceConfigMode(std::string& value); + Status + GetResourceConfigPool(std::vector& value); public: /* server config */ - Status SetServerConfigAddress(const std::string &value); - Status SetServerConfigPort(const std::string &value); - Status SetServerConfigDeployMode(const std::string &value); - Status SetServerConfigTimeZone(const std::string &value); + Status + SetServerConfigAddress(const std::string& value); + Status + SetServerConfigPort(const std::string& value); + Status + SetServerConfigDeployMode(const std::string& value); + Status + SetServerConfigTimeZone(const std::string& value); /* db config */ - Status SetDBConfigPrimaryPath(const std::string &value); - Status SetDBConfigSecondaryPath(const std::string &value); - Status SetDBConfigBackendUrl(const std::string &value); - Status SetDBConfigArchiveDiskThreshold(const std::string &value); - Status SetDBConfigArchiveDaysThreshold(const std::string &value); - Status SetDBConfigInsertBufferSize(const std::string &value); - Status SetDBConfigBuildIndexGPU(const std::string &value); + Status + SetDBConfigPrimaryPath(const std::string& value); + Status + SetDBConfigSecondaryPath(const std::string& value); + Status + SetDBConfigBackendUrl(const std::string& value); + Status + SetDBConfigArchiveDiskThreshold(const std::string& value); + Status + SetDBConfigArchiveDaysThreshold(const std::string& value); + Status + SetDBConfigInsertBufferSize(const std::string& value); + Status + SetDBConfigBuildIndexGPU(const std::string& value); /* metric config */ - Status SetMetricConfigEnableMonitor(const std::string &value); - Status SetMetricConfigCollector(const std::string &value); - Status SetMetricConfigPrometheusPort(const std::string &value); + Status + SetMetricConfigEnableMonitor(const std::string& value); + Status + SetMetricConfigCollector(const std::string& value); + Status + SetMetricConfigPrometheusPort(const std::string& value); /* cache config */ - Status SetCacheConfigCpuMemCapacity(const std::string &value); - Status SetCacheConfigCpuMemThreshold(const std::string &value); - Status SetCacheConfigGpuMemCapacity(const std::string &value); - Status SetCacheConfigGpuMemThreshold(const std::string &value); - Status SetCacheConfigCacheInsertData(const std::string &value); + Status + SetCacheConfigCpuCacheCapacity(const std::string& value); + Status + SetCacheConfigCpuCacheThreshold(const std::string& value); + Status + SetCacheConfigGpuCacheCapacity(const std::string& value); + Status + SetCacheConfigGpuCacheThreshold(const std::string& value); + Status + SetCacheConfigCacheInsertData(const std::string& value); /* engine config */ - Status SetEngineConfigBlasThreshold(const std::string &value); - Status SetEngineConfigOmpThreadNum(const std::string &value); + Status + SetEngineConfigUseBlasThreshold(const std::string& value); + Status + SetEngineConfigOmpThreadNum(const std::string& value); /* resource config */ - Status SetResourceConfigMode(const std::string &value); + Status + SetResourceConfigMode(const std::string& value); private: std::unordered_map> config_map_; std::mutex mutex_; }; -} // namespace server -} // namespace milvus -} // namespace zilliz - +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/DBWrapper.cpp b/cpp/src/server/DBWrapper.cpp index 73b11f034dba6ee53e093c0d4d1c3158bb1b180a..bb3bd012ab6d3da40471c4a862af85f70f5c0abe 100644 --- a/cpp/src/server/DBWrapper.cpp +++ b/cpp/src/server/DBWrapper.cpp @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #include "server/DBWrapper.h" #include "Config.h" #include "db/DBFactory.h" @@ -23,46 +22,54 @@ #include "utils/Log.h" #include "utils/StringHelpFunctions.h" -#include -#include #include +#include #include +#include +#include -namespace zilliz { namespace milvus { namespace server { -DBWrapper::DBWrapper() { -} - -Status DBWrapper::StartService() { +Status +DBWrapper::StartService() { Config& config = Config::GetInstance(); Status s; - //db config + // db config engine::DBOptions opt; s = config.GetDBConfigBackendUrl(opt.meta_.backend_uri_); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string path; s = config.GetDBConfigPrimaryPath(path); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } opt.meta_.path_ = path + "/db"; std::string db_slave_path; s = config.GetDBConfigSecondaryPath(db_slave_path); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } StringHelpFunctions::SplitStringByDelimeter(db_slave_path, ";", opt.meta_.slave_paths_); // cache config s = config.GetCacheConfigCacheInsertData(opt.insert_cache_immediately_); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } std::string mode; s = config.GetServerConfigDeployMode(mode); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } if (mode == "single") { opt.mode_ = engine::DBOptions::MODE::SINGLE; @@ -71,50 +78,61 @@ Status DBWrapper::StartService() { } else if (mode == "cluster_writable") { opt.mode_ = engine::DBOptions::MODE::CLUSTER_WRITABLE; } else { - std::cerr << - "ERROR: mode specified in server_config must be ['single', 'cluster_readonly', 'cluster_writable']" - << std::endl; + std::cerr << "ERROR: mode specified in server_config must be ['single', 'cluster_readonly', 'cluster_writable']" + << std::endl; kill(0, SIGUSR1); } // engine config int32_t omp_thread; s = config.GetEngineConfigOmpThreadNum(omp_thread); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + if (omp_thread > 0) { omp_set_num_threads(omp_thread); SERVER_LOG_DEBUG << "Specify openmp thread number: " << omp_thread; } else { uint32_t sys_thread_cnt = 8; if (CommonUtil::GetSystemAvailableThreads(sys_thread_cnt)) { - omp_thread = (int32_t)ceil(sys_thread_cnt*0.5); + omp_thread = static_cast(ceil(sys_thread_cnt * 0.5)); omp_set_num_threads(omp_thread); } } - //init faiss global variable - int32_t blas_threshold; - s = config.GetEngineConfigBlasThreshold(blas_threshold); - if (!s.ok()) return s; - faiss::distance_compute_blas_threshold = blas_threshold; + // init faiss global variable + int32_t use_blas_threshold; + s = config.GetEngineConfigUseBlasThreshold(use_blas_threshold); + if (!s.ok()) { + return s; + } + + faiss::distance_compute_blas_threshold = use_blas_threshold; - //set archive config + // set archive config engine::ArchiveConf::CriteriaT criterial; int32_t disk, days; s = config.GetDBConfigArchiveDiskThreshold(disk); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + if (disk > 0) { criterial[engine::ARCHIVE_CONF_DISK] = disk; } s = config.GetDBConfigArchiveDaysThreshold(days); - if (!s.ok()) return s; + if (!s.ok()) { + return s; + } + if (days > 0) { criterial[engine::ARCHIVE_CONF_DAYS] = days; } opt.meta_.archive_conf_.SetCriterias(criterial); - //create db root folder + // create db root folder Status status = CommonUtil::CreateDirectory(opt.meta_.path_); if (!status.ok()) { std::cerr << "ERROR! Failed to create database root path: " << opt.meta_.path_ << std::endl; @@ -129,20 +147,35 @@ Status DBWrapper::StartService() { } } - //create db instance + // create db instance try { db_ = engine::DBFactory::Build(opt); - } catch(std::exception& ex) { + } catch (std::exception& ex) { std::cerr << "ERROR! Failed to open database: " << ex.what() << std::endl; kill(0, SIGUSR1); } db_->Start(); + // preload table + std::string preload_tables; + s = config.GetDBConfigPreloadTable(preload_tables); + if (!s.ok()) { + return s; + } + + s = PreloadTables(preload_tables); + if (!s.ok()) { + std::cerr << "ERROR! Failed to preload tables: " << preload_tables << std::endl; + std::cerr << s.ToString() << std::endl; + kill(0, SIGUSR1); + } + return Status::OK(); } -Status DBWrapper::StopService() { +Status +DBWrapper::StopService() { if (db_) { db_->Stop(); } @@ -150,6 +183,34 @@ Status DBWrapper::StopService() { return Status::OK(); } -} // namespace server -} // namespace milvus -} // namespace zilliz +Status +DBWrapper::PreloadTables(const std::string& preload_tables) { + if (preload_tables.empty()) { + // do nothing + } else if (preload_tables == "*") { + // load all tables + std::vector table_schema_array; + db_->AllTables(table_schema_array); + + for (auto& schema : table_schema_array) { + auto status = db_->PreloadTable(schema.table_id_); + if (!status.ok()) { + return status; + } + } + } else { + std::vector table_names; + StringHelpFunctions::SplitStringByDelimeter(preload_tables, ",", table_names); + for (auto& name : table_names) { + auto status = db_->PreloadTable(name); + if (!status.ok()) { + return status; + } + } + } + + return Status::OK(); +} + +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/DBWrapper.h b/cpp/src/server/DBWrapper.h index 4648f4b59d5566a95df6277091a910c2c3c8e9b3..7016aa8805bf826678934bef880cd43590906301 100644 --- a/cpp/src/server/DBWrapper.h +++ b/cpp/src/server/DBWrapper.h @@ -17,41 +17,49 @@ #pragma once -#include "utils/Status.h" #include "db/DB.h" +#include "utils/Status.h" #include +#include -namespace zilliz { namespace milvus { namespace server { class DBWrapper { private: - DBWrapper(); + DBWrapper() = default; ~DBWrapper() = default; public: - static DBWrapper &GetInstance() { + static DBWrapper& + GetInstance() { static DBWrapper wrapper; return wrapper; } - static engine::DBPtr DB() { + static engine::DBPtr + DB() { return GetInstance().EngineDB(); } - Status StartService(); - Status StopService(); + Status + StartService(); + Status + StopService(); - engine::DBPtr EngineDB() { + engine::DBPtr + EngineDB() { return db_; } + private: + Status + PreloadTables(const std::string& preload_tables); + private: engine::DBPtr db_; }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/Server.cpp b/cpp/src/server/Server.cpp index f0675fc84d2d6edb2b51e9b35b092bf421283f1a..d9aabfc1798c2894950cb2493ba75e78c8081cbc 100644 --- a/cpp/src/server/Server.cpp +++ b/cpp/src/server/Server.cpp @@ -17,41 +17,38 @@ #include "server/Server.h" -#include #include #include #include #include +#include //#include -#include #include +#include -#include "server/grpc_impl/GrpcServer.h" +#include "DBWrapper.h" +#include "metrics/Metrics.h" +#include "scheduler/SchedInst.h" #include "server/Config.h" +#include "server/grpc_impl/GrpcServer.h" #include "utils/Log.h" #include "utils/LogUtil.h" #include "utils/SignalUtil.h" #include "utils/TimeRecorder.h" -#include "metrics/Metrics.h" -#include "scheduler/SchedInst.h" #include "wrapper/KnowhereResource.h" -#include "DBWrapper.h" -namespace zilliz { namespace milvus { namespace server { -Server & +Server& Server::GetInstance() { static Server server; return server; } void -Server::Init(int64_t daemonized, - const std::string &pid_filename, - const std::string &config_filename, - const std::string &log_config_file) { +Server::Init(int64_t daemonized, const std::string& pid_filename, const std::string& config_filename, + const std::string& log_config_file) { daemonized_ = daemonized; pid_filename_ = pid_filename; config_filename_ = config_filename; @@ -66,9 +63,9 @@ Server::Daemonize() { std::cout << "Milvus server run in daemonize mode"; -// std::string log_path(GetLogDirFullPath()); -// log_path += "zdb_server.(INFO/WARNNING/ERROR/CRITICAL)"; -// SERVER_LOG_INFO << "Log will be exported to: " + log_path); + // std::string log_path(GetLogDirFullPath()); + // log_path += "zdb_server.(INFO/WARNNING/ERROR/CRITICAL)"; + // SERVER_LOG_INFO << "Log will be exported to: " + log_path); pid_t pid = 0; @@ -148,7 +145,7 @@ Server::Daemonize() { void Server::Start() { - if (daemonized_) { + if (daemonized_ != 0) { Daemonize(); } @@ -160,7 +157,7 @@ Server::Start() { } /* log path is defined in Config file, so InitLog must be called after LoadConfig */ - Config &config = Config::GetInstance(); + Config& config = Config::GetInstance(); std::string time_zone; Status s = config.GetServerConfigTimeZone(time_zone); if (!s.ok()) { @@ -194,7 +191,7 @@ Server::Start() { StartService(); std::cout << "Milvus server start successfully." << std::endl; - } catch (std::exception &ex) { + } catch (std::exception& ex) { std::cerr << "Milvus server encounter exception: " << ex.what(); } } @@ -233,7 +230,7 @@ Server::Stop() { ErrorCode Server::LoadConfig() { - Config &config = Config::GetInstance(); + Config& config = Config::GetInstance(); Status s = config.LoadConfigFile(config_filename_); if (!s.ok()) { std::cerr << "Failed to load config file: " << config_filename_ << std::endl; @@ -264,6 +261,5 @@ Server::StopService() { engine::KnowhereResource::Finalize(); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/Server.h b/cpp/src/server/Server.h index 9bd0ce13b5accc597af280d4d327eaff4edffc37..658dbf7c3026cd432d98c46010cb8e364659d7a5 100644 --- a/cpp/src/server/Server.h +++ b/cpp/src/server/Server.h @@ -22,32 +22,37 @@ #include #include -namespace zilliz { namespace milvus { namespace server { class Server { public: - static Server &GetInstance(); + static Server& + GetInstance(); - void Init(int64_t daemonized, - const std::string &pid_filename, - const std::string &config_filename, - const std::string &log_config_file); + void + Init(int64_t daemonized, const std::string& pid_filename, const std::string& config_filename, + const std::string& log_config_file); - void Start(); - void Stop(); + void + Start(); + void + Stop(); private: Server() = default; ~Server() = default; - void Daemonize(); + void + Daemonize(); - ErrorCode LoadConfig(); + ErrorCode + LoadConfig(); - void StartService(); - void StopService(); + void + StartService(); + void + StopService(); private: int64_t daemonized_ = 0; @@ -57,6 +62,5 @@ class Server { std::string log_config_file_; }; // Server -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp b/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp index a64422aab734e91fff21640f7c6ece91beda38d0..a9ee3d77d0d0e761284ab6838ab16a0e1aef281b 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -15,31 +15,27 @@ // specific language governing permissions and limitations // under the License. - #include "server/grpc_impl/GrpcRequestHandler.h" #include "server/grpc_impl/GrpcRequestTask.h" #include "utils/TimeRecorder.h" #include -namespace zilliz { namespace milvus { namespace server { namespace grpc { ::grpc::Status -GrpcRequestHandler::CreateTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableSchema *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::CreateTable(::grpc::ServerContext* context, const ::milvus::grpc::TableSchema* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = CreateTableTask::Create(request); GrpcRequestScheduler::ExecTask(task_ptr, response); return ::grpc::Status::OK; } ::grpc::Status -GrpcRequestHandler::HasTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::BoolReply *response) { +GrpcRequestHandler::HasTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::BoolReply* response) { bool has_table = false; BaseTaskPtr task_ptr = HasTableTask::Create(request->table_name(), has_table); ::milvus::grpc::Status grpc_status; @@ -51,27 +47,24 @@ GrpcRequestHandler::HasTable(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::DropTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::DropTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = DropTableTask::Create(request->table_name()); GrpcRequestScheduler::ExecTask(task_ptr, response); return ::grpc::Status::OK; } ::grpc::Status -GrpcRequestHandler::CreateIndex(::grpc::ServerContext *context, - const ::milvus::grpc::IndexParam *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::grpc::IndexParam* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = CreateIndexTask::Create(request); GrpcRequestScheduler::ExecTask(task_ptr, response); return ::grpc::Status::OK; } ::grpc::Status -GrpcRequestHandler::Insert(::grpc::ServerContext *context, - const ::milvus::grpc::InsertParam *request, - ::milvus::grpc::VectorIds *response) { +GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc::InsertParam* request, + ::milvus::grpc::VectorIds* response) { BaseTaskPtr task_ptr = InsertTask::Create(request, response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -81,9 +74,8 @@ GrpcRequestHandler::Insert(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::Search(::grpc::ServerContext *context, - const ::milvus::grpc::SearchParam *request, - ::milvus::grpc::TopKQueryResultList *response) { +GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, + ::milvus::grpc::TopKQueryResultList* response) { std::vector file_id_array; BaseTaskPtr task_ptr = SearchTask::Create(request, file_id_array, response); ::milvus::grpc::Status grpc_status; @@ -94,14 +86,13 @@ GrpcRequestHandler::Search(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::SearchInFiles(::grpc::ServerContext *context, - const ::milvus::grpc::SearchInFilesParam *request, - ::milvus::grpc::TopKQueryResultList *response) { +GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, + ::milvus::grpc::TopKQueryResultList* response) { std::vector file_id_array; for (int i = 0; i < request->file_id_array_size(); i++) { file_id_array.push_back(request->file_id_array(i)); } - ::milvus::grpc::SearchInFilesParam *request_mutable = const_cast<::milvus::grpc::SearchInFilesParam *>(request); + ::milvus::grpc::SearchInFilesParam* request_mutable = const_cast<::milvus::grpc::SearchInFilesParam*>(request); BaseTaskPtr task_ptr = SearchTask::Create(request_mutable->mutable_search_param(), file_id_array, response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -111,9 +102,8 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::DescribeTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::TableSchema *response) { +GrpcRequestHandler::DescribeTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::TableSchema* response) { BaseTaskPtr task_ptr = DescribeTableTask::Create(request->table_name(), response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -123,9 +113,8 @@ GrpcRequestHandler::DescribeTable(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::CountTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::TableRowCount *response) { +GrpcRequestHandler::CountTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::TableRowCount* response) { int64_t row_count = 0; BaseTaskPtr task_ptr = CountTableTask::Create(request->table_name(), row_count); ::milvus::grpc::Status grpc_status; @@ -137,9 +126,8 @@ GrpcRequestHandler::CountTable(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::ShowTables(::grpc::ServerContext *context, - const ::milvus::grpc::Command *request, - ::milvus::grpc::TableNameList *response) { +GrpcRequestHandler::ShowTables(::grpc::ServerContext* context, const ::milvus::grpc::Command* request, + ::milvus::grpc::TableNameList* response) { BaseTaskPtr task_ptr = ShowTablesTask::Create(response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -149,9 +137,8 @@ GrpcRequestHandler::ShowTables(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::Cmd(::grpc::ServerContext *context, - const ::milvus::grpc::Command *request, - ::milvus::grpc::StringReply *response) { +GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Command* request, + ::milvus::grpc::StringReply* response) { std::string result; BaseTaskPtr task_ptr = CmdTask::Create(request->cmd(), result); ::milvus::grpc::Status grpc_status; @@ -163,9 +150,8 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::DeleteByRange(::grpc::ServerContext *context, - const ::milvus::grpc::DeleteByRangeParam *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::DeleteByRange(::grpc::ServerContext* context, const ::milvus::grpc::DeleteByRangeParam* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = DeleteByRangeTask::Create(request); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -175,9 +161,8 @@ GrpcRequestHandler::DeleteByRange(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::PreloadTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::PreloadTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = PreloadTableTask::Create(request->table_name()); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -187,9 +172,8 @@ GrpcRequestHandler::PreloadTable(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::DescribeIndex(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::IndexParam *response) { +GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::IndexParam* response) { BaseTaskPtr task_ptr = DescribeIndexTask::Create(request->table_name(), response); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -199,9 +183,8 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext *context, } ::grpc::Status -GrpcRequestHandler::DropIndex(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::Status *response) { +GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) { BaseTaskPtr task_ptr = DropIndexTask::Create(request->table_name()); ::milvus::grpc::Status grpc_status; GrpcRequestScheduler::ExecTask(task_ptr, &grpc_status); @@ -210,7 +193,6 @@ GrpcRequestHandler::DropIndex(::grpc::ServerContext *context, return ::grpc::Status::OK; } -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestHandler.h b/cpp/src/server/grpc_impl/GrpcRequestHandler.h index 549e1d9d3519eadf80912807459fd12d64880f16..1a9b5911540c6e86ff1d5053af77d1c85dc53239 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestHandler.h +++ b/cpp/src/server/grpc_impl/GrpcRequestHandler.h @@ -23,7 +23,6 @@ #include "grpc/gen-milvus/milvus.grpc.pb.h" #include "grpc/gen-status/status.pb.h" -namespace zilliz { namespace milvus { namespace server { namespace grpc { @@ -45,8 +44,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param context */ ::grpc::Status - CreateTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableSchema *request, ::milvus::grpc::Status *response) override; + CreateTable(::grpc::ServerContext* context, const ::milvus::grpc::TableSchema* request, + ::milvus::grpc::Status* response) override; /** * @brief Test table existence method @@ -64,8 +63,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param context */ ::grpc::Status - HasTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, ::milvus::grpc::BoolReply *response) override; + HasTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::BoolReply* response) override; /** * @brief Drop table method @@ -83,8 +82,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param context */ ::grpc::Status - DropTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, ::milvus::grpc::Status *response) override; + DropTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) override; /** * @brief build index by table method @@ -102,8 +101,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param context */ ::grpc::Status - CreateIndex(::grpc::ServerContext *context, - const ::milvus::grpc::IndexParam *request, ::milvus::grpc::Status *response) override; + CreateIndex(::grpc::ServerContext* context, const ::milvus::grpc::IndexParam* request, + ::milvus::grpc::Status* response) override; /** * @brief Insert vector array to table @@ -121,9 +120,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - Insert(::grpc::ServerContext *context, - const ::milvus::grpc::InsertParam *request, - ::milvus::grpc::VectorIds *response) override; + Insert(::grpc::ServerContext* context, const ::milvus::grpc::InsertParam* request, + ::milvus::grpc::VectorIds* response) override; /** * @brief Query vector @@ -146,34 +144,32 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param writer */ ::grpc::Status - Search(::grpc::ServerContext *context, - const ::milvus::grpc::SearchParam *request, - ::milvus::grpc::TopKQueryResultList *response) override; + Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, + ::milvus::grpc::TopKQueryResultList* response) override; /** - * @brief Internal use query interface - * - * This method is used to query vector in specified files. - * - * @param context, add context for every RPC - * @param request: - * file_id_array, specified files id array, queried. - * query_record_array, all vector are going to be queried. - * query_range_array, optional ranges for conditional search. If not specified, search whole table - * topk, how many similarity vectors will be searched. - * - * @param writer, write query result array. - * - * @return status - * - * @param context - * @param request - * @param writer - */ + * @brief Internal use query interface + * + * This method is used to query vector in specified files. + * + * @param context, add context for every RPC + * @param request: + * file_id_array, specified files id array, queried. + * query_record_array, all vector are going to be queried. + * query_range_array, optional ranges for conditional search. If not specified, search whole table + * topk, how many similarity vectors will be searched. + * + * @param writer, write query result array. + * + * @return status + * + * @param context + * @param request + * @param writer + */ ::grpc::Status - SearchInFiles(::grpc::ServerContext *context, - const ::milvus::grpc::SearchInFilesParam *request, - ::milvus::grpc::TopKQueryResultList *response) override; + SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, + ::milvus::grpc::TopKQueryResultList* response) override; /** * @brief Get table schema @@ -191,9 +187,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - DescribeTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::TableSchema *response) override; + DescribeTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::TableSchema* response) override; /** * @brief Get table row count @@ -211,9 +206,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param context */ ::grpc::Status - CountTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::TableRowCount *response) override; + CountTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::TableRowCount* response) override; /** * @brief List all tables in database @@ -231,9 +225,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param writer */ ::grpc::Status - ShowTables(::grpc::ServerContext *context, - const ::milvus::grpc::Command *request, - ::milvus::grpc::TableNameList *table_name_list) override; + ShowTables(::grpc::ServerContext* context, const ::milvus::grpc::Command* request, + ::milvus::grpc::TableNameList* response) override; /** * @brief Give the server status @@ -251,9 +244,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - Cmd(::grpc::ServerContext *context, - const ::milvus::grpc::Command *request, - ::milvus::grpc::StringReply *response) override; + Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Command* request, + ::milvus::grpc::StringReply* response) override; /** * @brief delete table by range @@ -270,9 +262,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - DeleteByRange(::grpc::ServerContext *context, - const ::milvus::grpc::DeleteByRangeParam *request, - ::milvus::grpc::Status *response) override; + DeleteByRange(::grpc::ServerContext* context, const ::milvus::grpc::DeleteByRangeParam* request, + ::milvus::grpc::Status* response) override; /** * @brief preload table @@ -289,9 +280,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - PreloadTable(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::Status *response) override; + PreloadTable(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) override; /** * @brief Describe index @@ -308,9 +298,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - DescribeIndex(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::IndexParam *response) override; + DescribeIndex(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::IndexParam* response) override; /** * @brief Drop index @@ -327,13 +316,10 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service { * @param response */ ::grpc::Status - DropIndex(::grpc::ServerContext *context, - const ::milvus::grpc::TableName *request, - ::milvus::grpc::Status *response) override; + DropIndex(::grpc::ServerContext* context, const ::milvus::grpc::TableName* request, + ::milvus::grpc::Status* response) override; }; -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz - +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestScheduler.cpp b/cpp/src/server/grpc_impl/GrpcRequestScheduler.cpp index 4c58195f37db9e0485189c4f901dbe0280c9a25a..ac35f82947a88ae7eaf89449d6a7aa4c5e28c350 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestScheduler.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestScheduler.cpp @@ -22,7 +22,6 @@ #include -namespace zilliz { namespace milvus { namespace server { namespace grpc { @@ -37,7 +36,6 @@ ErrorMap(ErrorCode code) { {SERVER_INVALID_ARGUMENT, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT}, {SERVER_FILE_NOT_FOUND, ::milvus::grpc::ErrorCode::FILE_NOT_FOUND}, {SERVER_NOT_IMPLEMENT, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR}, - {SERVER_BLOCKING_QUEUE_EMPTY, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR}, {SERVER_CANNOT_CREATE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FOLDER}, {SERVER_CANNOT_CREATE_FILE, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FILE}, {SERVER_CANNOT_DELETE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_DELETE_FOLDER}, @@ -58,7 +56,7 @@ ErrorMap(ErrorCode code) { {SERVER_INVALID_INDEX_FILE_SIZE, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT}, {SERVER_ILLEGAL_VECTOR_ID, ::milvus::grpc::ErrorCode::ILLEGAL_VECTOR_ID}, {SERVER_ILLEGAL_SEARCH_RESULT, ::milvus::grpc::ErrorCode::ILLEGAL_SEARCH_RESULT}, - {SERVER_CACHE_ERROR, ::milvus::grpc::ErrorCode::CACHE_FAILED}, + {SERVER_CACHE_FULL, ::milvus::grpc::ErrorCode::CACHE_FAILED}, {DB_META_TRANSACTION_FAILED, ::milvus::grpc::ErrorCode::META_FAILED}, {SERVER_BUILD_INDEX_ERROR, ::milvus::grpc::ErrorCode::BUILD_INDEX_ERROR}, {SERVER_OUT_OF_MEMORY, ::milvus::grpc::ErrorCode::OUT_OF_MEMORY}, @@ -70,13 +68,11 @@ ErrorMap(ErrorCode code) { return ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR; } } -} // namespace +} // namespace //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -GrpcBaseTask::GrpcBaseTask(const std::string &task_group, bool async) - : task_group_(task_group), - async_(async), - done_(false) { +GrpcBaseTask::GrpcBaseTask(const std::string& task_group, bool async) + : task_group_(task_group), async_(async), done_(false) { } GrpcBaseTask::~GrpcBaseTask() { @@ -97,7 +93,7 @@ GrpcBaseTask::Done() { } Status -GrpcBaseTask::SetStatus(ErrorCode error_code, const std::string &error_msg) { +GrpcBaseTask::SetStatus(ErrorCode error_code, const std::string& error_msg) { status_ = Status(error_code, error_msg); SERVER_LOG_ERROR << error_msg; return status_; @@ -106,16 +102,13 @@ GrpcBaseTask::SetStatus(ErrorCode error_code, const std::string &error_msg) { Status GrpcBaseTask::WaitToFinish() { std::unique_lock lock(finish_mtx_); - finish_cond_.wait(lock, [this] { - return done_; - }); + finish_cond_.wait(lock, [this] { return done_; }); return status_; } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -GrpcRequestScheduler::GrpcRequestScheduler() - : stopped_(false) { +GrpcRequestScheduler::GrpcRequestScheduler() : stopped_(false) { Start(); } @@ -124,17 +117,17 @@ GrpcRequestScheduler::~GrpcRequestScheduler() { } void -GrpcRequestScheduler::ExecTask(BaseTaskPtr &task_ptr, ::milvus::grpc::Status *grpc_status) { +GrpcRequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status* grpc_status) { if (task_ptr == nullptr) { return; } - GrpcRequestScheduler &scheduler = GrpcRequestScheduler::GetInstance(); + GrpcRequestScheduler& scheduler = GrpcRequestScheduler::GetInstance(); scheduler.ExecuteTask(task_ptr); if (!task_ptr->IsAsync()) { task_ptr->WaitToFinish(); - const Status &status = task_ptr->status(); + const Status& status = task_ptr->status(); if (!status.ok()) { grpc_status->set_reason(status.message()); grpc_status->set_error_code(ErrorMap(status.code())); @@ -178,7 +171,7 @@ GrpcRequestScheduler::Stop() { } Status -GrpcRequestScheduler::ExecuteTask(const BaseTaskPtr &task_ptr) { +GrpcRequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { if (task_ptr == nullptr) { return Status::OK(); } @@ -190,10 +183,10 @@ GrpcRequestScheduler::ExecuteTask(const BaseTaskPtr &task_ptr) { } if (task_ptr->IsAsync()) { - return Status::OK(); //async execution, caller need to call WaitToFinish at somewhere + return Status::OK(); // async execution, caller need to call WaitToFinish at somewhere } - return task_ptr->WaitToFinish();//sync execution + return task_ptr->WaitToFinish(); // sync execution } void @@ -206,7 +199,7 @@ GrpcRequestScheduler::TakeTaskToExecute(TaskQueuePtr task_queue) { BaseTaskPtr task = task_queue->Take(); if (task == nullptr) { SERVER_LOG_ERROR << "Take null from task queue, stop thread"; - break;//stop the thread + break; // stop the thread } try { @@ -214,14 +207,14 @@ GrpcRequestScheduler::TakeTaskToExecute(TaskQueuePtr task_queue) { if (!status.ok()) { SERVER_LOG_ERROR << "Task failed with code: " << status.ToString(); } - } catch (std::exception &ex) { + } catch (std::exception& ex) { SERVER_LOG_ERROR << "Task failed to execute: " << ex.what(); } } } Status -GrpcRequestScheduler::PutTaskToQueue(const BaseTaskPtr &task_ptr) { +GrpcRequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { std::lock_guard lock(queue_mtx_); std::string group_name = task_ptr->TaskGroup(); @@ -232,7 +225,7 @@ GrpcRequestScheduler::PutTaskToQueue(const BaseTaskPtr &task_ptr) { queue->Put(task_ptr); task_groups_.insert(std::make_pair(group_name, queue)); - //start a thread + // start a thread ThreadPtr thread = std::make_shared(&GrpcRequestScheduler::TakeTaskToExecute, this, queue); execute_threads_.push_back(thread); SERVER_LOG_INFO << "Create new thread for task group: " << group_name; @@ -241,7 +234,6 @@ GrpcRequestScheduler::PutTaskToQueue(const BaseTaskPtr &task_ptr) { return Status::OK(); } -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestScheduler.h b/cpp/src/server/grpc_impl/GrpcRequestScheduler.h index df5357a4bbad33ab31d3b79e60fb758da905723f..802d247fb5bbfb9d4becdaf3018b17f3e2cef49a 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestScheduler.h +++ b/cpp/src/server/grpc_impl/GrpcRequestScheduler.h @@ -17,51 +17,58 @@ #pragma once -#include "utils/Status.h" -#include "utils/BlockingQueue.h" #include "grpc/gen-status/status.grpc.pb.h" #include "grpc/gen-status/status.pb.h" +#include "utils/BlockingQueue.h" +#include "utils/Status.h" #include -#include -#include #include #include +#include +#include -namespace zilliz { namespace milvus { namespace server { namespace grpc { class GrpcBaseTask { protected: - explicit GrpcBaseTask(const std::string &task_group, bool async = false); + explicit GrpcBaseTask(const std::string& task_group, bool async = false); virtual ~GrpcBaseTask(); public: - Status Execute(); + Status + Execute(); - void Done(); + void + Done(); - Status WaitToFinish(); + Status + WaitToFinish(); - std::string TaskGroup() const { + std::string + TaskGroup() const { return task_group_; } - const Status &status() const { + const Status& + status() const { return status_; } - bool IsAsync() const { + bool + IsAsync() const { return async_; } protected: - virtual Status OnExecute() = 0; + virtual Status + OnExecute() = 0; - Status SetStatus(ErrorCode error_code, const std::string &msg); + Status + SetStatus(ErrorCode error_code, const std::string& error_msg); protected: mutable std::mutex finish_mtx_; @@ -80,27 +87,34 @@ using ThreadPtr = std::shared_ptr; class GrpcRequestScheduler { public: - static GrpcRequestScheduler &GetInstance() { + static GrpcRequestScheduler& + GetInstance() { static GrpcRequestScheduler scheduler; return scheduler; } - void Start(); + void + Start(); - void Stop(); + void + Stop(); - Status ExecuteTask(const BaseTaskPtr &task_ptr); + Status + ExecuteTask(const BaseTaskPtr& task_ptr); - static void ExecTask(BaseTaskPtr &task_ptr, ::milvus::grpc::Status *grpc_status); + static void + ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status* grpc_status); protected: GrpcRequestScheduler(); virtual ~GrpcRequestScheduler(); - void TakeTaskToExecute(TaskQueuePtr task_queue); + void + TakeTaskToExecute(TaskQueuePtr task_queue); - Status PutTaskToQueue(const BaseTaskPtr &task_ptr); + Status + PutTaskToQueue(const BaseTaskPtr& task_ptr); private: mutable std::mutex queue_mtx_; @@ -112,7 +126,6 @@ class GrpcRequestScheduler { bool stopped_; }; -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp index 5702b80bec5a78ee1d65c423ea90e7daec3b6f7a..1279cbac9f7bf48c7d55ebf56e9a46d36f5acd1f 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp @@ -19,32 +19,33 @@ #include #include -#include #include +#include //#include -#include "server/Server.h" +#include "../../../version.h" +#include "GrpcServer.h" +#include "db/Utils.h" +#include "scheduler/SchedInst.h" #include "server/DBWrapper.h" +#include "server/Server.h" #include "utils/CommonUtil.h" #include "utils/Log.h" #include "utils/TimeRecorder.h" #include "utils/ValidationUtil.h" -#include "GrpcServer.h" -#include "db/Utils.h" -#include "scheduler/SchedInst.h" -#include "../../../version.h" -namespace zilliz { namespace milvus { namespace server { namespace grpc { -static const char *DQL_TASK_GROUP = "dql"; -static const char *DDL_DML_TASK_GROUP = "ddl_dml"; -static const char *PING_TASK_GROUP = "ping"; +static const char* DQL_TASK_GROUP = "dql"; +static const char* DDL_DML_TASK_GROUP = "ddl_dml"; +static const char* PING_TASK_GROUP = "ping"; -using DB_META = zilliz::milvus::engine::meta::Meta; -using DB_DATE = zilliz::milvus::engine::meta::DateT; +constexpr int64_t DAY_SECONDS = 24 * 60 * 60; + +using DB_META = milvus::engine::meta::Meta; +using DB_DATE = milvus::engine::meta::DateT; namespace { engine::EngineType @@ -79,13 +80,10 @@ IndexType(engine::EngineType type) { return map_type[type]; } -constexpr int64_t DAY_SECONDS = 24 * 60 * 60; - Status -ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range> &range_array, - std::vector &dates) { +ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range>& range_array, std::vector& dates) { dates.clear(); - for (auto &range : range_array) { + for (auto& range : range_array) { time_t tt_start, tt_end; tm tm_start, tm_end; if (!CommonUtil::TimeStrToTime(range.start_value(), tt_start, tm_start)) { @@ -96,37 +94,34 @@ ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range> &range_array, return Status(SERVER_INVALID_TIME_RANGE, "Invalid time range: " + range.start_value()); } - int64_t days = (tt_end > tt_start) ? (tt_end - tt_start) / DAY_SECONDS : (tt_start - tt_end) / - DAY_SECONDS; - if (days == 0) { + int64_t days = (tt_end - tt_start) / DAY_SECONDS; + if (days <= 0) { return Status(SERVER_INVALID_TIME_RANGE, - "Invalid time range: " + range.start_value() + " to " + range.end_value()); + "Invalid time range: The start-date should be smaller than end-date!"); } - //range: [start_day, end_day) + // range: [start_day, end_day) for (int64_t i = 0; i < days; i++) { time_t tt_day = tt_start + DAY_SECONDS * i; tm tm_day; CommonUtil::ConvertTime(tt_day, tm_day); - int64_t date = tm_day.tm_year * 10000 + tm_day.tm_mon * 100 + - tm_day.tm_mday;//according to db logic + int64_t date = tm_day.tm_year * 10000 + tm_day.tm_mon * 100 + tm_day.tm_mday; // according to db logic dates.push_back(date); } } return Status::OK(); } -} // namespace +} // namespace //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema *schema) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - schema_(schema) { +CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema* schema) + : GrpcBaseTask(DDL_DML_TASK_GROUP), schema_(schema) { } BaseTaskPtr -CreateTableTask::Create(const ::milvus::grpc::TableSchema *schema) { +CreateTableTask::Create(const ::milvus::grpc::TableSchema* schema) { if (schema == nullptr) { SERVER_LOG_ERROR << "grpc input is null!"; return nullptr; @@ -139,7 +134,7 @@ CreateTableTask::OnExecute() { TimeRecorder rc("CreateTableTask"); try { - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(schema_->table_name()); if (!status.ok()) { return status; @@ -160,23 +155,23 @@ CreateTableTask::OnExecute() { return status; } - //step 2: construct table schema + // step 2: construct table schema engine::meta::TableSchema table_info; table_info.table_id_ = schema_->table_name(); - table_info.dimension_ = (uint16_t) schema_->dimension(); + table_info.dimension_ = static_cast(schema_->dimension()); table_info.index_file_size_ = schema_->index_file_size(); table_info.metric_type_ = schema_->metric_type(); - //step 3: create table + // step 3: create table status = DBWrapper::DB()->CreateTable(table_info); if (!status.ok()) { - //table could exist + // table could exist if (status.code() == DB_ALREADY_EXIST) { return Status(SERVER_INVALID_TABLE_NAME, status.message()); } return status; } - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -186,14 +181,12 @@ CreateTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name), - schema_(schema) { +DescribeTableTask::DescribeTableTask(const std::string& table_name, ::milvus::grpc::TableSchema* schema) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), schema_(schema) { } BaseTaskPtr -DescribeTableTask::Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema) { +DescribeTableTask::Create(const std::string& table_name, ::milvus::grpc::TableSchema* schema) { return std::shared_ptr(new DescribeTableTask(table_name, schema)); } @@ -202,13 +195,13 @@ DescribeTableTask::OnExecute() { TimeRecorder rc("DescribeTableTask"); try { - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: get table info + // step 2: get table info engine::meta::TableSchema table_info; table_info.table_id_ = table_name_; status = DBWrapper::DB()->DescribeTable(table_info); @@ -220,7 +213,7 @@ DescribeTableTask::OnExecute() { schema_->set_dimension(table_info.dimension_); schema_->set_index_file_size(table_info.index_file_size_); schema_->set_metric_type(table_info.metric_type_); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -230,13 +223,12 @@ DescribeTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam *index_param) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - index_param_(index_param) { +CreateIndexTask::CreateIndexTask(const ::milvus::grpc::IndexParam* index_param) + : GrpcBaseTask(DDL_DML_TASK_GROUP), index_param_(index_param) { } BaseTaskPtr -CreateIndexTask::Create(const ::milvus::grpc::IndexParam *index_param) { +CreateIndexTask::Create(const ::milvus::grpc::IndexParam* index_param) { if (index_param == nullptr) { SERVER_LOG_ERROR << "grpc input is null!"; return nullptr; @@ -249,7 +241,7 @@ CreateIndexTask::OnExecute() { try { TimeRecorder rc("CreateIndexTask"); - //step 1: check arguments + // step 1: check arguments std::string table_name_ = index_param_->table_name(); auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { @@ -266,7 +258,7 @@ CreateIndexTask::OnExecute() { return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); } - auto &grpc_index = index_param_->index(); + auto& grpc_index = index_param_->index(); status = ValidationUtil::ValidateTableIndexType(grpc_index.index_type()); if (!status.ok()) { return status; @@ -277,7 +269,7 @@ CreateIndexTask::OnExecute() { return status; } - //step 2: check table existence + // step 2: check table existence engine::TableIndex index; index.engine_type_ = grpc_index.index_type(); index.nlist_ = grpc_index.nlist(); @@ -287,7 +279,7 @@ CreateIndexTask::OnExecute() { } rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -295,14 +287,12 @@ CreateIndexTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -HasTableTask::HasTableTask(const std::string &table_name, bool &has_table) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name), - has_table_(has_table) { +HasTableTask::HasTableTask(const std::string& table_name, bool& has_table) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), has_table_(has_table) { } BaseTaskPtr -HasTableTask::Create(const std::string &table_name, bool &has_table) { +HasTableTask::Create(const std::string& table_name, bool& has_table) { return std::shared_ptr(new HasTableTask(table_name, has_table)); } @@ -311,20 +301,20 @@ HasTableTask::OnExecute() { try { TimeRecorder rc("HasTableTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence status = DBWrapper::DB()->HasTable(table_name_, has_table_); if (!status.ok()) { return status; } rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -332,13 +322,12 @@ HasTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DropTableTask::DropTableTask(const std::string &table_name) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name) { +DropTableTask::DropTableTask(const std::string& table_name) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name) { } BaseTaskPtr -DropTableTask::Create(const std::string &table_name) { +DropTableTask::Create(const std::string& table_name) { return std::shared_ptr(new DropTableTask(table_name)); } @@ -347,13 +336,13 @@ DropTableTask::OnExecute() { try { TimeRecorder rc("DropTableTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence engine::meta::TableSchema table_info; table_info.table_id_ = table_name_; status = DBWrapper::DB()->DescribeTable(table_info); @@ -367,7 +356,7 @@ DropTableTask::OnExecute() { rc.ElapseFromBegin("check validation"); - //step 3: Drop table + // step 3: Drop table std::vector dates; status = DBWrapper::DB()->DeleteTable(table_name_, dates); if (!status.ok()) { @@ -375,7 +364,7 @@ DropTableTask::OnExecute() { } rc.ElapseFromBegin("total cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -383,13 +372,12 @@ DropTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -ShowTablesTask::ShowTablesTask(::milvus::grpc::TableNameList *table_name_list) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_list_(table_name_list) { +ShowTablesTask::ShowTablesTask(::milvus::grpc::TableNameList* table_name_list) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_list_(table_name_list) { } BaseTaskPtr -ShowTablesTask::Create(::milvus::grpc::TableNameList *table_name_list) { +ShowTablesTask::Create(::milvus::grpc::TableNameList* table_name_list) { return std::shared_ptr(new ShowTablesTask(table_name_list)); } @@ -401,23 +389,19 @@ ShowTablesTask::OnExecute() { return statuts; } - for (auto &schema : schema_array) { + for (auto& schema : schema_array) { table_name_list_->add_table_names(schema.table_id_); } return Status::OK(); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -InsertTask::InsertTask(const ::milvus::grpc::InsertParam *insert_param, - ::milvus::grpc::VectorIds *record_ids) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - insert_param_(insert_param), - record_ids_(record_ids) { +InsertTask::InsertTask(const ::milvus::grpc::InsertParam* insert_param, ::milvus::grpc::VectorIds* record_ids) + : GrpcBaseTask(DDL_DML_TASK_GROUP), insert_param_(insert_param), record_ids_(record_ids) { } BaseTaskPtr -InsertTask::Create(const ::milvus::grpc::InsertParam *insert_param, - ::milvus::grpc::VectorIds *record_ids) { +InsertTask::Create(const ::milvus::grpc::InsertParam* insert_param, ::milvus::grpc::VectorIds* record_ids) { if (insert_param == nullptr) { SERVER_LOG_ERROR << "grpc input is null!"; return nullptr; @@ -430,7 +414,7 @@ InsertTask::OnExecute() { try { TimeRecorder rc("InsertVectorTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(insert_param_->table_name()); if (!status.ok()) { return status; @@ -441,35 +425,33 @@ InsertTask::OnExecute() { if (!record_ids_->vector_id_array().empty()) { if (record_ids_->vector_id_array().size() != insert_param_->row_record_array_size()) { - return Status(SERVER_ILLEGAL_VECTOR_ID, - "Size of vector ids is not equal to row record array size"); + return Status(SERVER_ILLEGAL_VECTOR_ID, "Size of vector ids is not equal to row record array size"); } } - //step 2: check table existence + // step 2: check table existence engine::meta::TableSchema table_info; table_info.table_id_ = insert_param_->table_name(); status = DBWrapper::DB()->DescribeTable(table_info); if (!status.ok()) { if (status.code() == DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, - "Table " + insert_param_->table_name() + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, "Table " + insert_param_->table_name() + " not exists"); } else { return status; } } - //step 3: check table flag - //all user provide id, or all internal id + // step 3: check table flag + // all user provide id, or all internal id bool user_provide_ids = !insert_param_->row_id_array().empty(); - //user already provided id before, all insert action require user id - if ((table_info.flag_ & engine::meta::FLAG_MASK_HAS_USERID) && !user_provide_ids) { + // user already provided id before, all insert action require user id + if ((table_info.flag_ & engine::meta::FLAG_MASK_HAS_USERID) != 0 && !user_provide_ids) { return Status(SERVER_ILLEGAL_VECTOR_ID, "Table vector ids are user defined, please provide id for this batch"); } - //user didn't provided id before, no need to provide user id - if ((table_info.flag_ & engine::meta::FLAG_MASK_NO_USERID) && user_provide_ids) { + // user didn't provided id before, no need to provide user id + if ((table_info.flag_ & engine::meta::FLAG_MASK_NO_USERID) != 0 && user_provide_ids) { return Status(SERVER_ILLEGAL_VECTOR_ID, "Table vector ids are auto generated, no need to provide id for this batch"); } @@ -477,15 +459,15 @@ InsertTask::OnExecute() { rc.RecordSection("check validation"); #ifdef MILVUS_ENABLE_PROFILING - std::string fname = "/tmp/insert_" + std::to_string(this->insert_param_->row_record_array_size()) - + ".profiling"; + std::string fname = + "/tmp/insert_" + std::to_string(this->insert_param_->row_record_array_size()) + ".profiling"; ProfilerStart(fname.c_str()); #endif - //step 4: prepare float data + // step 4: prepare float data std::vector vec_f(insert_param_->row_record_array_size() * table_info.dimension_, 0); - // TODO: change to one dimension array in protobuf or use multiple-thread to copy the data + // TODO(yk): change to one dimension array or use multiple-thread to copy the data for (size_t i = 0; i < insert_param_->row_record_array_size(); i++) { if (insert_param_->row_record_array(i).vector_data().empty()) { return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array data is empty"); @@ -493,25 +475,23 @@ InsertTask::OnExecute() { uint64_t vec_dim = insert_param_->row_record_array(i).vector_data().size(); if (vec_dim != table_info.dimension_) { ErrorCode error_code = SERVER_INVALID_VECTOR_DIMENSION; - std::string error_msg = "Invalid row record dimension: " + std::to_string(vec_dim) - + " vs. table dimension:" + - std::to_string(table_info.dimension_); + std::string error_msg = "Invalid row record dimension: " + std::to_string(vec_dim) + + " vs. table dimension:" + std::to_string(table_info.dimension_); return Status(error_code, error_msg); } - memcpy(&vec_f[i * table_info.dimension_], - insert_param_->row_record_array(i).vector_data().data(), + memcpy(&vec_f[i * table_info.dimension_], insert_param_->row_record_array(i).vector_data().data(), table_info.dimension_ * sizeof(float)); } rc.ElapseFromBegin("prepare vectors data"); - //step 5: insert vectors - auto vec_count = (uint64_t) insert_param_->row_record_array_size(); + // step 5: insert vectors + auto vec_count = static_cast(insert_param_->row_record_array_size()); std::vector vec_ids(insert_param_->row_id_array_size(), 0); if (!insert_param_->row_id_array().empty()) { - const int64_t *src_data = insert_param_->row_id_array().data(); - int64_t *target_data = vec_ids.data(); - memcpy(target_data, src_data, (size_t) (sizeof(int64_t) * insert_param_->row_id_array_size())); + const int64_t* src_data = insert_param_->row_id_array().data(); + int64_t* target_data = vec_ids.data(); + memcpy(target_data, src_data, static_cast(sizeof(int64_t) * insert_param_->row_id_array_size())); } status = DBWrapper::DB()->InsertVectors(insert_param_->table_name(), vec_count, vec_f.data(), vec_ids); @@ -525,12 +505,12 @@ InsertTask::OnExecute() { auto ids_size = record_ids_->vector_id_array_size(); if (ids_size != vec_count) { - std::string msg = "Add " + std::to_string(vec_count) + " vectors but only return " - + std::to_string(ids_size) + " id"; + std::string msg = + "Add " + std::to_string(vec_count) + " vectors but only return " + std::to_string(ids_size) + " id"; return Status(SERVER_ILLEGAL_VECTOR_ID, msg); } - //step 6: update table flag + // step 6: update table flag user_provide_ids ? table_info.flag_ |= engine::meta::FLAG_MASK_HAS_USERID : table_info.flag_ |= engine::meta::FLAG_MASK_NO_USERID; status = DBWrapper::DB()->UpdateTableFlag(insert_param_->table_name(), table_info.flag_); @@ -541,7 +521,7 @@ InsertTask::OnExecute() { rc.RecordSection("add vectors to engine"); rc.ElapseFromBegin("total cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -549,9 +529,8 @@ InsertTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -SearchTask::SearchTask(const ::milvus::grpc::SearchParam *search_vector_infos, - const std::vector &file_id_array, - ::milvus::grpc::TopKQueryResultList *response) +SearchTask::SearchTask(const ::milvus::grpc::SearchParam* search_vector_infos, + const std::vector& file_id_array, ::milvus::grpc::TopKQueryResultList* response) : GrpcBaseTask(DQL_TASK_GROUP), search_param_(search_vector_infos), file_id_array_(file_id_array), @@ -559,9 +538,8 @@ SearchTask::SearchTask(const ::milvus::grpc::SearchParam *search_vector_infos, } BaseTaskPtr -SearchTask::Create(const ::milvus::grpc::SearchParam *search_vector_infos, - const std::vector &file_id_array, - ::milvus::grpc::TopKQueryResultList *response) { +SearchTask::Create(const ::milvus::grpc::SearchParam* search_vector_infos, + const std::vector& file_id_array, ::milvus::grpc::TopKQueryResultList* response) { if (search_vector_infos == nullptr) { SERVER_LOG_ERROR << "grpc input is null!"; return nullptr; @@ -578,14 +556,14 @@ SearchTask::OnExecute() { std::string hdr = "SearchTask(k=" + std::to_string(top_k) + ", nprob=" + std::to_string(nprobe) + ")"; TimeRecorder rc(hdr); - //step 1: check table name + // step 1: check table name std::string table_name_ = search_param_->table_name(); auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence engine::meta::TableSchema table_info; table_info.table_id_ = table_name_; status = DBWrapper::DB()->DescribeTable(table_info); @@ -597,7 +575,7 @@ SearchTask::OnExecute() { } } - //step 3: check search parameter + // step 3: check search parameter status = ValidationUtil::ValidateSearchTopk(top_k, table_info); if (!status.ok()) { return status; @@ -612,7 +590,7 @@ SearchTask::OnExecute() { return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); } - //step 4: check date range, and convert to db dates + // step 4: check date range, and convert to db dates std::vector dates; std::vector<::milvus::grpc::Range> range_array; for (size_t i = 0; i < search_param_->query_range_array_size(); i++) { @@ -626,8 +604,7 @@ SearchTask::OnExecute() { rc.RecordSection("check validation"); - - //step 5: prepare float data + // step 5: prepare float data auto record_array_size = search_param_->query_record_array_size(); std::vector vec_f(record_array_size * table_info.dimension_, 0); for (size_t i = 0; i < record_array_size; i++) { @@ -637,33 +614,32 @@ SearchTask::OnExecute() { uint64_t query_vec_dim = search_param_->query_record_array(i).vector_data().size(); if (query_vec_dim != table_info.dimension_) { ErrorCode error_code = SERVER_INVALID_VECTOR_DIMENSION; - std::string error_msg = "Invalid row record dimension: " + std::to_string(query_vec_dim) - + " vs. table dimension:" + std::to_string(table_info.dimension_); + std::string error_msg = "Invalid row record dimension: " + std::to_string(query_vec_dim) + + " vs. table dimension:" + std::to_string(table_info.dimension_); return Status(error_code, error_msg); } - memcpy(&vec_f[i * table_info.dimension_], - search_param_->query_record_array(i).vector_data().data(), + memcpy(&vec_f[i * table_info.dimension_], search_param_->query_record_array(i).vector_data().data(), table_info.dimension_ * sizeof(float)); } rc.RecordSection("prepare vector data"); - //step 6: search vectors + // step 6: search vectors engine::QueryResults results; - auto record_count = (uint64_t) search_param_->query_record_array().size(); + auto record_count = (uint64_t)search_param_->query_record_array().size(); #ifdef MILVUS_ENABLE_PROFILING - std::string fname = "/tmp/search_nq_" + std::to_string(this->search_param_->query_record_array_size()) - + ".profiling"; + std::string fname = + "/tmp/search_nq_" + std::to_string(this->search_param_->query_record_array_size()) + ".profiling"; ProfilerStart(fname.c_str()); #endif if (file_id_array_.empty()) { - status = DBWrapper::DB()->Query(table_name_, (size_t) top_k, record_count, nprobe, - vec_f.data(), dates, results); + status = + DBWrapper::DB()->Query(table_name_, (size_t)top_k, record_count, nprobe, vec_f.data(), dates, results); } else { - status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t) top_k, - record_count, nprobe, vec_f.data(), dates, results); + status = DBWrapper::DB()->Query(table_name_, file_id_array_, (size_t)top_k, record_count, nprobe, + vec_f.data(), dates, results); } #ifdef MILVUS_ENABLE_PROFILING @@ -676,29 +652,29 @@ SearchTask::OnExecute() { } if (results.empty()) { - return Status::OK(); //empty table + return Status::OK(); // empty table } if (results.size() != record_count) { - std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " - + std::to_string(results.size()) + " results"; + std::string msg = "Search " + std::to_string(record_count) + " vectors but only return " + + std::to_string(results.size()) + " results"; return Status(SERVER_ILLEGAL_SEARCH_RESULT, msg); } - //step 7: construct result array - for (auto &result : results) { - ::milvus::grpc::TopKQueryResult *topk_query_result = topk_result_list->add_topk_query_result(); - for (auto &pair : result) { - ::milvus::grpc::QueryResult *grpc_result = topk_query_result->add_query_result_arrays(); + // step 7: construct result array + for (auto& result : results) { + ::milvus::grpc::TopKQueryResult* topk_query_result = topk_result_list->add_topk_query_result(); + for (auto& pair : result) { + ::milvus::grpc::QueryResult* grpc_result = topk_query_result->add_query_result_arrays(); grpc_result->set_id(pair.first); grpc_result->set_distance(pair.second); } } - //step 8: print time cost percent + // step 8: print time cost percent rc.RecordSection("construct result and send"); rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -706,14 +682,12 @@ SearchTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CountTableTask::CountTableTask(const std::string &table_name, int64_t &row_count) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name), - row_count_(row_count) { +CountTableTask::CountTableTask(const std::string& table_name, int64_t& row_count) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), row_count_(row_count) { } BaseTaskPtr -CountTableTask::Create(const std::string &table_name, int64_t &row_count) { +CountTableTask::Create(const std::string& table_name, int64_t& row_count) { return std::shared_ptr(new CountTableTask(table_name, row_count)); } @@ -722,23 +696,27 @@ CountTableTask::OnExecute() { try { TimeRecorder rc("GetTableRowCountTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: get row count + // step 2: get row count uint64_t row_count = 0; status = DBWrapper::DB()->GetTableRowCount(table_name_, row_count); if (!status.ok()) { - return status; + if (status.code(), DB_NOT_FOUND) { + return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + } else { + return status; + } } - row_count_ = (int64_t) row_count; + row_count_ = static_cast(row_count); rc.ElapseFromBegin("total cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -746,14 +724,12 @@ CountTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -CmdTask::CmdTask(const std::string &cmd, std::string &result) - : GrpcBaseTask(PING_TASK_GROUP), - cmd_(cmd), - result_(result) { +CmdTask::CmdTask(const std::string& cmd, std::string& result) + : GrpcBaseTask(PING_TASK_GROUP), cmd_(cmd), result_(result) { } BaseTaskPtr -CmdTask::Create(const std::string &cmd, std::string &result) { +CmdTask::Create(const std::string& cmd, std::string& result) { return std::shared_ptr(new CmdTask(cmd, result)); } @@ -771,13 +747,12 @@ CmdTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - delete_by_range_param_(delete_by_range_param) { +DeleteByRangeTask::DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam* delete_by_range_param) + : GrpcBaseTask(DDL_DML_TASK_GROUP), delete_by_range_param_(delete_by_range_param) { } BaseTaskPtr -DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param) { +DeleteByRangeTask::Create(const ::milvus::grpc::DeleteByRangeParam* delete_by_range_param) { if (delete_by_range_param == nullptr) { SERVER_LOG_ERROR << "grpc input is null!"; return nullptr; @@ -791,14 +766,14 @@ DeleteByRangeTask::OnExecute() { try { TimeRecorder rc("DeleteByRangeTask"); - //step 1: check arguments + // step 1: check arguments std::string table_name = delete_by_range_param_->table_name(); auto status = ValidationUtil::ValidateTableName(table_name); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence engine::meta::TableSchema table_info; table_info.table_id_ = table_name; status = DBWrapper::DB()->DescribeTable(table_info); @@ -812,7 +787,7 @@ DeleteByRangeTask::OnExecute() { rc.ElapseFromBegin("check validation"); - //step 3: check date range, and convert to db dates + // step 3: check date range, and convert to db dates std::vector dates; ErrorCode error_code = SERVER_SUCCESS; std::string error_msg; @@ -832,7 +807,7 @@ DeleteByRangeTask::OnExecute() { if (!status.ok()) { return status; } - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -840,13 +815,12 @@ DeleteByRangeTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -PreloadTableTask::PreloadTableTask(const std::string &table_name) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name) { +PreloadTableTask::PreloadTableTask(const std::string& table_name) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name) { } BaseTaskPtr -PreloadTableTask::Create(const std::string &table_name) { +PreloadTableTask::Create(const std::string& table_name) { return std::shared_ptr(new PreloadTableTask(table_name)); } @@ -855,20 +829,20 @@ PreloadTableTask::OnExecute() { try { TimeRecorder rc("PreloadTableTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence status = DBWrapper::DB()->PreloadTable(table_name_); if (!status.ok()) { return status; } rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -876,16 +850,12 @@ PreloadTableTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DescribeIndexTask::DescribeIndexTask(const std::string &table_name, - ::milvus::grpc::IndexParam *index_param) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name), - index_param_(index_param) { +DescribeIndexTask::DescribeIndexTask(const std::string& table_name, ::milvus::grpc::IndexParam* index_param) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name), index_param_(index_param) { } BaseTaskPtr -DescribeIndexTask::Create(const std::string &table_name, - ::milvus::grpc::IndexParam *index_param) { +DescribeIndexTask::Create(const std::string& table_name, ::milvus::grpc::IndexParam* index_param) { return std::shared_ptr(new DescribeIndexTask(table_name, index_param)); } @@ -894,13 +864,13 @@ DescribeIndexTask::OnExecute() { try { TimeRecorder rc("DescribeIndexTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; } - //step 2: check table existence + // step 2: check table existence engine::TableIndex index; status = DBWrapper::DB()->DescribeIndex(table_name_, index); if (!status.ok()) { @@ -912,7 +882,7 @@ DescribeIndexTask::OnExecute() { index_param_->mutable_index()->set_nlist(index.nlist_); rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } @@ -920,13 +890,12 @@ DescribeIndexTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -DropIndexTask::DropIndexTask(const std::string &table_name) - : GrpcBaseTask(DDL_DML_TASK_GROUP), - table_name_(table_name) { +DropIndexTask::DropIndexTask(const std::string& table_name) + : GrpcBaseTask(DDL_DML_TASK_GROUP), table_name_(table_name) { } BaseTaskPtr -DropIndexTask::Create(const std::string &table_name) { +DropIndexTask::Create(const std::string& table_name) { return std::shared_ptr(new DropIndexTask(table_name)); } @@ -935,7 +904,7 @@ DropIndexTask::OnExecute() { try { TimeRecorder rc("DropIndexTask"); - //step 1: check arguments + // step 1: check arguments auto status = ValidationUtil::ValidateTableName(table_name_); if (!status.ok()) { return status; @@ -951,21 +920,20 @@ DropIndexTask::OnExecute() { return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); } - //step 2: check table existence + // step 2: check table existence status = DBWrapper::DB()->DropIndex(table_name_); if (!status.ok()) { return status; } rc.ElapseFromBegin("totally cost"); - } catch (std::exception &ex) { + } catch (std::exception& ex) { return Status(SERVER_UNEXPECTED_ERROR, ex.what()); } return Status::OK(); } -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.h b/cpp/src/server/grpc_impl/GrpcRequestTask.h index 4c8c038d444ff87bb8037dd5d507300e3b0b4660..ad2828ebf3471b3ed028abe586dd7b6b1907a577 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.h +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.h @@ -17,19 +17,18 @@ #pragma once +#include "db/Types.h" #include "server/grpc_impl/GrpcRequestScheduler.h" #include "utils/Status.h" -#include "db/Types.h" #include "grpc/gen-milvus/milvus.grpc.pb.h" #include "grpc/gen-status/status.pb.h" #include #include -#include #include +#include -namespace zilliz { namespace milvus { namespace server { namespace grpc { @@ -38,60 +37,60 @@ namespace grpc { class CreateTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::TableSchema *schema); + Create(const ::milvus::grpc::TableSchema* schema); protected: - explicit CreateTableTask(const ::milvus::grpc::TableSchema *request); + explicit CreateTableTask(const ::milvus::grpc::TableSchema* schema); Status OnExecute() override; private: - const ::milvus::grpc::TableSchema *schema_; + const ::milvus::grpc::TableSchema* schema_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class HasTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name, bool &has_table); + Create(const std::string& table_name, bool& has_table); protected: - HasTableTask(const std::string &request, bool &has_table); + HasTableTask(const std::string& table_name, bool& has_table); Status OnExecute() override; private: std::string table_name_; - bool &has_table_; + bool& has_table_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DescribeTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name, ::milvus::grpc::TableSchema *schema); + Create(const std::string& table_name, ::milvus::grpc::TableSchema* schema); protected: - DescribeTableTask(const std::string &table_name, ::milvus::grpc::TableSchema *schema); + DescribeTableTask(const std::string& table_name, ::milvus::grpc::TableSchema* schema); Status OnExecute() override; private: std::string table_name_; - ::milvus::grpc::TableSchema *schema_; + ::milvus::grpc::TableSchema* schema_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DropTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name); + Create(const std::string& table_name); protected: - explicit DropTableTask(const std::string &table_name); + explicit DropTableTask(const std::string& table_name); Status OnExecute() override; @@ -104,133 +103,129 @@ class DropTableTask : public GrpcBaseTask { class CreateIndexTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::IndexParam *index_Param); + Create(const ::milvus::grpc::IndexParam* index_param); protected: - explicit CreateIndexTask(const ::milvus::grpc::IndexParam *index_Param); + explicit CreateIndexTask(const ::milvus::grpc::IndexParam* index_param); Status OnExecute() override; private: - const ::milvus::grpc::IndexParam *index_param_; + const ::milvus::grpc::IndexParam* index_param_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class ShowTablesTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(::milvus::grpc::TableNameList *table_name_list); + Create(::milvus::grpc::TableNameList* table_name_list); protected: - explicit ShowTablesTask(::milvus::grpc::TableNameList *table_name_list); + explicit ShowTablesTask(::milvus::grpc::TableNameList* table_name_list); Status OnExecute() override; private: - ::milvus::grpc::TableNameList *table_name_list_; + ::milvus::grpc::TableNameList* table_name_list_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class InsertTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::InsertParam *insert_Param, - ::milvus::grpc::VectorIds *record_ids_); + Create(const ::milvus::grpc::InsertParam* insert_param, ::milvus::grpc::VectorIds* record_ids); protected: - InsertTask(const ::milvus::grpc::InsertParam *insert_Param, - ::milvus::grpc::VectorIds *record_ids_); + InsertTask(const ::milvus::grpc::InsertParam* insert_param, ::milvus::grpc::VectorIds* record_ids); Status OnExecute() override; private: - const ::milvus::grpc::InsertParam *insert_param_; - ::milvus::grpc::VectorIds *record_ids_; + const ::milvus::grpc::InsertParam* insert_param_; + ::milvus::grpc::VectorIds* record_ids_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class SearchTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::SearchParam *search_param, - const std::vector &file_id_array, - ::milvus::grpc::TopKQueryResultList *response); + Create(const ::milvus::grpc::SearchParam* search_param, const std::vector& file_id_array, + ::milvus::grpc::TopKQueryResultList* response); protected: - SearchTask(const ::milvus::grpc::SearchParam *search_param, - const std::vector &file_id_array, - ::milvus::grpc::TopKQueryResultList *response); + SearchTask(const ::milvus::grpc::SearchParam* search_param, const std::vector& file_id_array, + ::milvus::grpc::TopKQueryResultList* response); Status OnExecute() override; private: - const ::milvus::grpc::SearchParam *search_param_; + const ::milvus::grpc::SearchParam* search_param_; std::vector file_id_array_; - ::milvus::grpc::TopKQueryResultList *topk_result_list; + ::milvus::grpc::TopKQueryResultList* topk_result_list; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class CountTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name, int64_t &row_count); + Create(const std::string& table_name, int64_t& row_count); protected: - CountTableTask(const std::string &table_name, int64_t &row_count); + CountTableTask(const std::string& table_name, int64_t& row_count); Status OnExecute() override; private: std::string table_name_; - int64_t &row_count_; + int64_t& row_count_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class CmdTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &cmd, std::string &result); + Create(const std::string& cmd, std::string& result); protected: - CmdTask(const std::string &cmd, std::string &result); + CmdTask(const std::string& cmd, std::string& result); Status OnExecute() override; private: std::string cmd_; - std::string &result_; + std::string& result_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DeleteByRangeTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param); + Create(const ::milvus::grpc::DeleteByRangeParam* delete_by_range_param); protected: - explicit DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param); + explicit DeleteByRangeTask(const ::milvus::grpc::DeleteByRangeParam* delete_by_range_param); Status OnExecute() override; private: - const ::milvus::grpc::DeleteByRangeParam *delete_by_range_param_; + const ::milvus::grpc::DeleteByRangeParam* delete_by_range_param_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class PreloadTableTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name); + Create(const std::string& table_name); protected: - explicit PreloadTableTask(const std::string &table_name); + explicit PreloadTableTask(const std::string& table_name); Status OnExecute() override; @@ -243,29 +238,27 @@ class PreloadTableTask : public GrpcBaseTask { class DescribeIndexTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name, - ::milvus::grpc::IndexParam *index_param); + Create(const std::string& table_name, ::milvus::grpc::IndexParam* index_param); protected: - DescribeIndexTask(const std::string &table_name, - ::milvus::grpc::IndexParam *index_param); + DescribeIndexTask(const std::string& table_name, ::milvus::grpc::IndexParam* index_param); Status OnExecute() override; private: std::string table_name_; - ::milvus::grpc::IndexParam *index_param_; + ::milvus::grpc::IndexParam* index_param_; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DropIndexTask : public GrpcBaseTask { public: static BaseTaskPtr - Create(const std::string &table_name); + Create(const std::string& table_name); protected: - explicit DropIndexTask(const std::string &table_name); + explicit DropIndexTask(const std::string& table_name); Status OnExecute() override; @@ -274,7 +267,6 @@ class DropIndexTask : public GrpcBaseTask { std::string table_name_; }; -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcServer.cpp b/cpp/src/server/grpc_impl/GrpcServer.cpp index 065271dd5534d451e2c71bbdc397932fc286b414..5e0c5f3169bc2a026bcf7104bc5456e46da5eea2 100644 --- a/cpp/src/server/grpc_impl/GrpcServer.cpp +++ b/cpp/src/server/grpc_impl/GrpcServer.cpp @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -#include "grpc/gen-milvus/milvus.grpc.pb.h" #include "server/grpc_impl/GrpcServer.h" +#include "GrpcRequestHandler.h" +#include "grpc/gen-milvus/milvus.grpc.pb.h" #include "server/Config.h" #include "server/DBWrapper.h" #include "utils/Log.h" -#include "GrpcRequestHandler.h" #include #include @@ -36,21 +36,22 @@ #include #include -namespace zilliz { namespace milvus { namespace server { namespace grpc { constexpr int64_t MESSAGE_SIZE = -1; -//this class is to check port occupation during server start +// this class is to check port occupation during server start class NoReusePortOption : public ::grpc::ServerBuilderOption { public: - void UpdateArguments(::grpc::ChannelArguments *args) override { + void + UpdateArguments(::grpc::ChannelArguments* args) override { args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); } - void UpdatePlugins(std::vector> *plugins) override { + void + UpdatePlugins(std::vector>* plugins) override { } }; @@ -70,7 +71,7 @@ GrpcServer::Stop() { Status GrpcServer::StartService() { - Config &config = Config::GetInstance(); + Config& config = Config::GetInstance(); std::string address, port; Status s; @@ -87,7 +88,7 @@ GrpcServer::StartService() { ::grpc::ServerBuilder builder; builder.SetOption(std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); - builder.SetMaxReceiveMessageSize(MESSAGE_SIZE); //default 4 * 1024 * 1024 + builder.SetMaxReceiveMessageSize(MESSAGE_SIZE); // default 4 * 1024 * 1024 builder.SetMaxSendMessageSize(MESSAGE_SIZE); builder.SetCompressionAlgorithmSupportStatus(GRPC_COMPRESS_STREAM_GZIP, true); @@ -114,7 +115,6 @@ GrpcServer::StopService() { return Status::OK(); } -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/server/grpc_impl/GrpcServer.h b/cpp/src/server/grpc_impl/GrpcServer.h index 9101f144b3b79a1f45cba894feee177aee9adec7..aeaf9f0dcaf406a62ee716f5aac57f04a7feef53 100644 --- a/cpp/src/server/grpc_impl/GrpcServer.h +++ b/cpp/src/server/grpc_impl/GrpcServer.h @@ -19,40 +19,43 @@ #include "utils/Status.h" -#include +#include #include +#include #include #include -#include -namespace zilliz { namespace milvus { namespace server { namespace grpc { class GrpcServer { public: - static GrpcServer &GetInstance() { + static GrpcServer& + GetInstance() { static GrpcServer grpc_server; return grpc_server; } - void Start(); - void Stop(); + void + Start(); + void + Stop(); private: GrpcServer() = default; ~GrpcServer() = default; - Status StartService(); - Status StopService(); + Status + StartService(); + Status + StopService(); private: std::unique_ptr<::grpc::Server> server_ptr_; std::shared_ptr thread_ptr_; }; -} // namespace grpc -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace grpc +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/BlockingQueue.h b/cpp/src/utils/BlockingQueue.h index b98fe28ef8da35b55ff1ec7760e9636f9d9a539a..dc7968fcb6eb10e5448fcf05f4de6e654bba2261 100644 --- a/cpp/src/utils/BlockingQueue.h +++ b/cpp/src/utils/BlockingQueue.h @@ -23,33 +23,40 @@ #include #include -namespace zilliz { namespace milvus { namespace server { -template +template class BlockingQueue { public: BlockingQueue() : mtx(), full_(), empty_() { } - BlockingQueue(const BlockingQueue &rhs) = delete; + BlockingQueue(const BlockingQueue& rhs) = delete; - BlockingQueue &operator=(const BlockingQueue &rhs) = delete; + BlockingQueue& + operator=(const BlockingQueue& rhs) = delete; - void Put(const T &task); + void + Put(const T& task); - T Take(); + T + Take(); - T Front(); + T + Front(); - T Back(); + T + Back(); - size_t Size(); + size_t + Size(); - bool Empty(); + bool + Empty(); - void SetCapacity(const size_t capacity); + void + SetCapacity(const size_t capacity); private: mutable std::mutex mtx; @@ -59,8 +66,7 @@ class BlockingQueue { size_t capacity_ = 32; }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus #include "./BlockingQueue.inl" diff --git a/cpp/src/utils/BlockingQueue.inl b/cpp/src/utils/BlockingQueue.inl index ed15aac77a0597b7c16ba83d6f6e17fbfaaff005..c4318c5fc631ea82b6eaf8bf048bc0f489dfcd74 100644 --- a/cpp/src/utils/BlockingQueue.inl +++ b/cpp/src/utils/BlockingQueue.inl @@ -18,7 +18,7 @@ #pragma once -namespace zilliz { + namespace milvus { namespace server { @@ -94,5 +94,5 @@ BlockingQueue::SetCapacity(const size_t capacity) { } // namespace server } // namespace milvus -} // namespace zilliz + diff --git a/cpp/src/utils/CommonUtil.cpp b/cpp/src/utils/CommonUtil.cpp index 0116e321e0cd4cffb29091c37db9aca139e4a234..fbf3112aeb29b80e6e228fc322ab09f8b5d1df94 100644 --- a/cpp/src/utils/CommonUtil.cpp +++ b/cpp/src/utils/CommonUtil.cpp @@ -18,15 +18,15 @@ #include "utils/CommonUtil.h" #include "utils/Log.h" -#include -#include -#include -#include -#include #include +#include #include -#include +#include +#include #include +#include +#include +#include #include "boost/filesystem.hpp" @@ -38,36 +38,36 @@ #define THREAD_MULTIPLY_CPU 1 #endif -namespace zilliz { namespace milvus { namespace server { namespace fs = boost::filesystem; bool -CommonUtil::GetSystemMemInfo(uint64_t &total_mem, uint64_t &free_mem) { +CommonUtil::GetSystemMemInfo(uint64_t& total_mem, uint64_t& free_mem) { struct sysinfo info; int ret = sysinfo(&info); total_mem = info.totalram; free_mem = info.freeram; - return ret == 0;//succeed 0, failed -1 + return ret == 0; // succeed 0, failed -1 } bool -CommonUtil::GetSystemAvailableThreads(uint32_t &thread_count) { - //threadCnt = std::thread::hardware_concurrency(); +CommonUtil::GetSystemAvailableThreads(uint32_t& thread_count) { + // threadCnt = std::thread::hardware_concurrency(); thread_count = sysconf(_SC_NPROCESSORS_CONF); thread_count *= THREAD_MULTIPLY_CPU; - if (thread_count == 0) + if (thread_count == 0) { thread_count = 8; + } return true; } bool -CommonUtil::IsDirectoryExist(const std::string &path) { - DIR *dp = nullptr; +CommonUtil::IsDirectoryExist(const std::string& path) { + DIR* dp = nullptr; if ((dp = opendir(path.c_str())) == nullptr) { return false; } @@ -77,7 +77,7 @@ CommonUtil::IsDirectoryExist(const std::string &path) { } Status -CommonUtil::CreateDirectory(const std::string &path) { +CommonUtil::CreateDirectory(const std::string& path) { if (path.empty()) { return Status::OK(); } @@ -85,7 +85,7 @@ CommonUtil::CreateDirectory(const std::string &path) { struct stat directory_stat; int status = stat(path.c_str(), &directory_stat); if (status == 0) { - return Status::OK();//already exist + return Status::OK(); // already exist } fs::path fs_path(path); @@ -97,7 +97,7 @@ CommonUtil::CreateDirectory(const std::string &path) { status = stat(path.c_str(), &directory_stat); if (status == 0) { - return Status::OK();//already exist + return Status::OK(); // already exist } int makeOK = mkdir(path.c_str(), S_IRWXU | S_IRGRP | S_IROTH); @@ -110,17 +110,16 @@ CommonUtil::CreateDirectory(const std::string &path) { namespace { void -RemoveDirectory(const std::string &path) { - DIR *dir = nullptr; - struct dirent *dmsg; +RemoveDirectory(const std::string& path) { + DIR* dir = nullptr; + struct dirent* dmsg; const int32_t buf_size = 256; char file_name[buf_size]; std::string folder_name = path + "/%s"; if ((dir = opendir(path.c_str())) != nullptr) { while ((dmsg = readdir(dir)) != nullptr) { - if (strcmp(dmsg->d_name, ".") != 0 - && strcmp(dmsg->d_name, "..") != 0) { + if (strcmp(dmsg->d_name, ".") != 0 && strcmp(dmsg->d_name, "..") != 0) { snprintf(file_name, buf_size, folder_name.c_str(), dmsg->d_name); std::string tmp = file_name; if (tmp.find(".") == std::string::npos) { @@ -136,10 +135,10 @@ RemoveDirectory(const std::string &path) { } remove(path.c_str()); } -} // namespace +} // namespace Status -CommonUtil::DeleteDirectory(const std::string &path) { +CommonUtil::DeleteDirectory(const std::string& path) { if (path.empty()) { return Status::OK(); } @@ -155,18 +154,18 @@ CommonUtil::DeleteDirectory(const std::string &path) { } bool -CommonUtil::IsFileExist(const std::string &path) { +CommonUtil::IsFileExist(const std::string& path) { return (access(path.c_str(), F_OK) == 0); } uint64_t -CommonUtil::GetFileSize(const std::string &path) { +CommonUtil::GetFileSize(const std::string& path) { struct stat file_info; if (stat(path.c_str(), &file_info) < 0) { return 0; - } else { - return (uint64_t) file_info.st_size; } + + return static_cast(file_info.st_size); } std::string @@ -195,21 +194,13 @@ CommonUtil::GetExePath() { } bool -CommonUtil::TimeStrToTime(const std::string &time_str, - time_t &time_integer, - tm &time_struct, - const std::string &format) { +CommonUtil::TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, + const std::string& format) { time_integer = 0; memset(&time_struct, 0, sizeof(tm)); - int ret = sscanf(time_str.c_str(), - format.c_str(), - &(time_struct.tm_year), - &(time_struct.tm_mon), - &(time_struct.tm_mday), - &(time_struct.tm_hour), - &(time_struct.tm_min), - &(time_struct.tm_sec)); + int ret = sscanf(time_str.c_str(), format.c_str(), &(time_struct.tm_year), &(time_struct.tm_mon), + &(time_struct.tm_mday), &(time_struct.tm_hour), &(time_struct.tm_min), &(time_struct.tm_sec)); if (ret <= 0) { return false; } @@ -222,15 +213,14 @@ CommonUtil::TimeStrToTime(const std::string &time_str, } void -CommonUtil::ConvertTime(time_t time_integer, tm &time_struct) { +CommonUtil::ConvertTime(time_t time_integer, tm& time_struct) { localtime_r(&time_integer, &time_struct); } void -CommonUtil::ConvertTime(tm time_struct, time_t &time_integer) { +CommonUtil::ConvertTime(tm time_struct, time_t& time_integer) { time_integer = mktime(&time_struct); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/CommonUtil.h b/cpp/src/utils/CommonUtil.h old mode 100755 new mode 100644 index b059067d5087b014af27231e44e5df3f5c44c3e0..939bdd6d31e8ee46ead21b70bafbc930c17be9d8 --- a/cpp/src/utils/CommonUtil.h +++ b/cpp/src/utils/CommonUtil.h @@ -19,37 +19,44 @@ #include "utils/Status.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace server { class CommonUtil { public: - static bool GetSystemMemInfo(uint64_t &total_mem, uint64_t &free_mem); - static bool GetSystemAvailableThreads(uint32_t &thread_count); - - static bool IsFileExist(const std::string &path); - static uint64_t GetFileSize(const std::string &path); - static bool IsDirectoryExist(const std::string &path); - static Status CreateDirectory(const std::string &path); - static Status DeleteDirectory(const std::string &path); - - static std::string GetFileName(std::string filename); - static std::string GetExePath(); - - static bool TimeStrToTime(const std::string &time_str, - time_t &time_integer, - tm &time_struct, - const std::string &format = "%d-%d-%d %d:%d:%d"); - - static void ConvertTime(time_t time_integer, tm &time_struct); - static void ConvertTime(tm time_struct, time_t &time_integer); + static bool + GetSystemMemInfo(uint64_t& total_mem, uint64_t& free_mem); + static bool + GetSystemAvailableThreads(uint32_t& thread_count); + + static bool + IsFileExist(const std::string& path); + static uint64_t + GetFileSize(const std::string& path); + static bool + IsDirectoryExist(const std::string& path); + static Status + CreateDirectory(const std::string& path); + static Status + DeleteDirectory(const std::string& path); + + static std::string + GetFileName(std::string filename); + static std::string + GetExePath(); + + static bool + TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, + const std::string& format = "%d-%d-%d %d:%d:%d"); + + static void + ConvertTime(time_t time_integer, tm& time_struct); + static void + ConvertTime(tm time_struct, time_t& time_integer); }; -} // namespace server -} // namespace milvus -} // namespace zilliz - +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/Error.h b/cpp/src/utils/Error.h index d0bcc5cf8d8c703bb5f92857c2fa9ae52cb10d40..dfc400ca9a07fc148b146bceb9157a3165029079 100644 --- a/cpp/src/utils/Error.h +++ b/cpp/src/utils/Error.h @@ -21,13 +21,12 @@ #include #include -namespace zilliz { namespace milvus { using ErrorCode = int32_t; constexpr ErrorCode SERVER_SUCCESS = 0; -constexpr ErrorCode SERVER_ERROR_CODE_BASE = 0x30000; +constexpr ErrorCode SERVER_ERROR_CODE_BASE = 30000; constexpr ErrorCode ToServerErrorCode(const ErrorCode error_code) { @@ -35,7 +34,7 @@ ToServerErrorCode(const ErrorCode error_code) { } constexpr ErrorCode DB_SUCCESS = 0; -constexpr ErrorCode DB_ERROR_CODE_BASE = 0x40000; +constexpr ErrorCode DB_ERROR_CODE_BASE = 40000; constexpr ErrorCode ToDbErrorCode(const ErrorCode error_code) { @@ -43,21 +42,20 @@ ToDbErrorCode(const ErrorCode error_code) { } constexpr ErrorCode KNOWHERE_SUCCESS = 0; -constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 0x50000; +constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 50000; constexpr ErrorCode ToKnowhereErrorCode(const ErrorCode error_code) { return KNOWHERE_ERROR_CODE_BASE + error_code; } -//server error code +// server error code constexpr ErrorCode SERVER_UNEXPECTED_ERROR = ToServerErrorCode(1); constexpr ErrorCode SERVER_UNSUPPORTED_ERROR = ToServerErrorCode(2); constexpr ErrorCode SERVER_NULL_POINTER = ToServerErrorCode(3); constexpr ErrorCode SERVER_INVALID_ARGUMENT = ToServerErrorCode(4); constexpr ErrorCode SERVER_FILE_NOT_FOUND = ToServerErrorCode(5); constexpr ErrorCode SERVER_NOT_IMPLEMENT = ToServerErrorCode(6); -constexpr ErrorCode SERVER_BLOCKING_QUEUE_EMPTY = ToServerErrorCode(7); constexpr ErrorCode SERVER_CANNOT_CREATE_FOLDER = ToServerErrorCode(8); constexpr ErrorCode SERVER_CANNOT_CREATE_FILE = ToServerErrorCode(9); constexpr ErrorCode SERVER_CANNOT_DELETE_FOLDER = ToServerErrorCode(10); @@ -75,7 +73,7 @@ constexpr ErrorCode SERVER_INVALID_ROWRECORD_ARRAY = ToServerErrorCode(107); constexpr ErrorCode SERVER_INVALID_TOPK = ToServerErrorCode(108); constexpr ErrorCode SERVER_ILLEGAL_VECTOR_ID = ToServerErrorCode(109); constexpr ErrorCode SERVER_ILLEGAL_SEARCH_RESULT = ToServerErrorCode(110); -constexpr ErrorCode SERVER_CACHE_ERROR = ToServerErrorCode(111); +constexpr ErrorCode SERVER_CACHE_FULL = ToServerErrorCode(111); constexpr ErrorCode SERVER_WRITE_ERROR = ToServerErrorCode(112); constexpr ErrorCode SERVER_INVALID_NPROBE = ToServerErrorCode(113); constexpr ErrorCode SERVER_INVALID_INDEX_NLIST = ToServerErrorCode(114); @@ -83,7 +81,7 @@ constexpr ErrorCode SERVER_INVALID_INDEX_METRIC_TYPE = ToServerErrorCode(115); constexpr ErrorCode SERVER_INVALID_INDEX_FILE_SIZE = ToServerErrorCode(116); constexpr ErrorCode SERVER_OUT_OF_MEMORY = ToServerErrorCode(117); -//db error code +// db error code constexpr ErrorCode DB_META_TRANSACTION_FAILED = ToDbErrorCode(1); constexpr ErrorCode DB_ERROR = ToDbErrorCode(2); constexpr ErrorCode DB_NOT_FOUND = ToDbErrorCode(3); @@ -92,7 +90,7 @@ constexpr ErrorCode DB_INVALID_PATH = ToDbErrorCode(5); constexpr ErrorCode DB_INCOMPATIB_META = ToDbErrorCode(6); constexpr ErrorCode DB_INVALID_META_URI = ToDbErrorCode(7); -//knowhere error code +// knowhere error code constexpr ErrorCode KNOWHERE_ERROR = ToKnowhereErrorCode(1); constexpr ErrorCode KNOWHERE_INVALID_ARGUMENT = ToKnowhereErrorCode(2); constexpr ErrorCode KNOWHERE_UNEXPECTED_ERROR = ToKnowhereErrorCode(3); @@ -101,17 +99,18 @@ constexpr ErrorCode KNOWHERE_NO_SPACE = ToKnowhereErrorCode(4); namespace server { class ServerException : public std::exception { public: - ServerException(ErrorCode error_code, - const std::string &message = std::string()) + explicit ServerException(ErrorCode error_code, const std::string& message = std::string()) : error_code_(error_code), message_(message) { } public: - ErrorCode error_code() const { + ErrorCode + error_code() const { return error_code_; } - virtual const char *what() const noexcept { + virtual const char* + what() const noexcept { return message_.c_str(); } @@ -119,7 +118,6 @@ class ServerException : public std::exception { ErrorCode error_code_; std::string message_; }; -} // namespace server +} // namespace server } // namespace milvus -} // namespace zilliz diff --git a/cpp/src/utils/Exception.h b/cpp/src/utils/Exception.h index 7e30c372bc7ab5c54506606eaa3c5f8686cf1503..a2d8473fa37e9053754e69c7005cc6d9a308cbf5 100644 --- a/cpp/src/utils/Exception.h +++ b/cpp/src/utils/Exception.h @@ -22,21 +22,20 @@ #include #include -namespace zilliz { namespace milvus { class Exception : public std::exception { public: - Exception(ErrorCode code, const std::string &message) - : code_(code), - message_(message) { + Exception(ErrorCode code, const std::string& message) : code_(code), message_(message) { } - ErrorCode code() const throw() { + ErrorCode + code() const throw() { return code_; } - virtual const char *what() const throw() { + virtual const char* + what() const throw() { if (message_.empty()) { return "Default Exception."; } else { @@ -54,14 +53,11 @@ class Exception : public std::exception { class InvalidArgumentException : public Exception { public: - InvalidArgumentException() - : Exception(SERVER_INVALID_ARGUMENT, "Invalid Argument") { + InvalidArgumentException() : Exception(SERVER_INVALID_ARGUMENT, "Invalid Argument") { } - explicit InvalidArgumentException(const std::string &message) - : Exception(SERVER_INVALID_ARGUMENT, message) { + explicit InvalidArgumentException(const std::string& message) : Exception(SERVER_INVALID_ARGUMENT, message) { } }; -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/Log.h b/cpp/src/utils/Log.h index b1402d9e3e043b658fbcd3935ac89a6d6f422097..1dd116367a22666a4b8cc0f47ff3340108c3e7c1 100644 --- a/cpp/src/utils/Log.h +++ b/cpp/src/utils/Log.h @@ -19,7 +19,6 @@ #include "utils/easylogging++.h" -namespace zilliz { namespace milvus { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -52,5 +51,4 @@ namespace milvus { #define WRAPPER_LOG_ERROR LOG(ERROR) << WRAPPER_DOMAIN_NAME #define WRAPPER_LOG_FATAL LOG(FATAL) << WRAPPER_DOMAIN_NAME -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/LogUtil.cpp b/cpp/src/utils/LogUtil.cpp index 0e710a6e3aadb7aea6921b43ed596030d0beaf24..4a962f466ce845dc09e28f789cc67ed4ed920eaa 100644 --- a/cpp/src/utils/LogUtil.cpp +++ b/cpp/src/utils/LogUtil.cpp @@ -18,10 +18,9 @@ #include "utils/LogUtil.h" #include -#include #include +#include -namespace zilliz { namespace milvus { namespace server { @@ -32,20 +31,20 @@ static int warning_idx = 0; static int trace_idx = 0; static int error_idx = 0; static int fatal_idx = 0; -} +} // namespace // TODO(yzb) : change the easylogging library to get the log level from parameter rather than filename void -RolloutHandler(const char *filename, std::size_t size, el::Level level) { - char *dirc = strdup(filename); - char *basec = strdup(filename); - char *dir = dirname(dirc); - char *base = basename(basec); +RolloutHandler(const char* filename, std::size_t size, el::Level level) { + char* dirc = strdup(filename); + char* basec = strdup(filename); + char* dir = dirname(dirc); + char* base = basename(basec); std::string s(base); std::stringstream ss; - std::string - list[] = {"\\", " ", "\'", "\"", "*", "\?", "{", "}", ";", "<", ">", "|", "^", "&", "$", "#", "!", "`", "~"}; + std::string list[] = {"\\", " ", "\'", "\"", "*", "\?", "{", "}", ";", "<", + ">", "|", "^", "&", "$", "#", "!", "`", "~"}; std::string::size_type position; for (auto substr : list) { position = 0; @@ -82,7 +81,7 @@ RolloutHandler(const char *filename, std::size_t size, el::Level level) { } Status -InitLog(const std::string &log_config_file) { +InitLog(const std::string& log_config_file) { el::Configurations conf(log_config_file); el::Loggers::reconfigureAllLoggers(conf); @@ -93,6 +92,5 @@ InitLog(const std::string &log_config_file) { return Status::OK(); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/LogUtil.h b/cpp/src/utils/LogUtil.h index e3f7bed51f7aceb3752094492821f08d28d06af3..9926939442385511b27f42acfb4bd39fa6fad439 100644 --- a/cpp/src/utils/LogUtil.h +++ b/cpp/src/utils/LogUtil.h @@ -20,26 +20,24 @@ #include "utils/Status.h" #include "utils/easylogging++.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace server { Status -InitLog(const std::string &log_config_file); +InitLog(const std::string& log_config_file); void -RolloutHandler(const char *filename, std::size_t size, el::Level level); +RolloutHandler(const char* filename, std::size_t size, el::Level level); #define SHOW_LOCATION #ifdef SHOW_LOCATION -#define LOCATION_INFO "[" << zilliz::sql::server::GetFileName(__FILE__) << ":" << __LINE__ << "] " +#define LOCATION_INFO "[" << sql::server::GetFileName(__FILE__) << ":" << __LINE__ << "] " #else #define LOCATION_INFO "" #endif -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/SignalUtil.cpp b/cpp/src/utils/SignalUtil.cpp index 5f74852995081641790651469cf6eeb2b4f49618..5531aaed27f898680e875cdb74bbba496ad2d2f8 100644 --- a/cpp/src/utils/SignalUtil.cpp +++ b/cpp/src/utils/SignalUtil.cpp @@ -19,11 +19,10 @@ #include "src/server/Server.h" #include "utils/Log.h" -#include -#include #include +#include +#include -namespace zilliz { namespace milvus { namespace server { @@ -34,7 +33,7 @@ SignalUtil::HandleSignal(int signum) { case SIGUSR2: { SERVER_LOG_INFO << "Server received signal: " << signum; - server::Server &server = server::Server::GetInstance(); + server::Server& server = server::Server::GetInstance(); server.Stop(); exit(0); @@ -43,7 +42,7 @@ SignalUtil::HandleSignal(int signum) { SERVER_LOG_INFO << "Server received critical signal: " << signum; SignalUtil::PrintStacktrace(); - server::Server &server = server::Server::GetInstance(); + server::Server& server = server::Server::GetInstance(); server.Stop(); exit(1); @@ -56,9 +55,9 @@ SignalUtil::PrintStacktrace() { SERVER_LOG_INFO << "Call stack:"; const int size = 32; - void *array[size]; + void* array[size]; int stack_num = backtrace(array, size); - char **stacktrace = backtrace_symbols(array, stack_num); + char** stacktrace = backtrace_symbols(array, stack_num); for (int i = 0; i < stack_num; ++i) { std::string info = stacktrace[i]; SERVER_LOG_INFO << info; @@ -66,6 +65,5 @@ SignalUtil::PrintStacktrace() { free(stacktrace); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/SignalUtil.h b/cpp/src/utils/SignalUtil.h index c9cd41839b42a7a81635b99623100cdb32b8c66c..2ecfdecfba0be1f304043be8faef30666bbaee00 100644 --- a/cpp/src/utils/SignalUtil.h +++ b/cpp/src/utils/SignalUtil.h @@ -17,16 +17,16 @@ #pragma once -namespace zilliz { namespace milvus { namespace server { class SignalUtil { public: - static void HandleSignal(int signum); - static void PrintStacktrace(); + static void + HandleSignal(int signum); + static void + PrintStacktrace(); }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/Status.cpp b/cpp/src/utils/Status.cpp index 5b512b3369eac9e0e4790320b82bebc400b30540..ad97717cf7190422bc6db09e13723bbc5130cbbe 100644 --- a/cpp/src/utils/Status.cpp +++ b/cpp/src/utils/Status.cpp @@ -19,17 +19,16 @@ #include -namespace zilliz { namespace milvus { constexpr int CODE_WIDTH = sizeof(StatusCode); -Status::Status(StatusCode code, const std::string &msg) { - //4 bytes store code - //4 bytes store message length - //the left bytes store message string - const uint32_t length = (uint32_t) msg.size(); - char *result = new char[length + sizeof(length) + CODE_WIDTH]; +Status::Status(StatusCode code, const std::string& msg) { + // 4 bytes store code + // 4 bytes store message length + // the left bytes store message string + const uint32_t length = (uint32_t)msg.size(); + auto result = new char[length + sizeof(length) + CODE_WIDTH]; std::memcpy(result, &code, CODE_WIDTH); std::memcpy(result + CODE_WIDTH, &length, sizeof(length)); memcpy(result + sizeof(length) + CODE_WIDTH, msg.data(), length); @@ -37,38 +36,35 @@ Status::Status(StatusCode code, const std::string &msg) { state_ = result; } -Status::Status() - : state_(nullptr) { +Status::Status() : state_(nullptr) { } Status::~Status() { delete state_; } -Status::Status(const Status &s) - : state_(nullptr) { +Status::Status(const Status& s) : state_(nullptr) { CopyFrom(s); } -Status & -Status::operator=(const Status &s) { +Status& +Status::operator=(const Status& s) { CopyFrom(s); return *this; } -Status::Status(Status &&s) - : state_(nullptr) { +Status::Status(Status&& s) : state_(nullptr) { MoveFrom(s); } -Status & -Status::operator=(Status &&s) { +Status& +Status::operator=(Status&& s) { MoveFrom(s); return *this; } void -Status::CopyFrom(const Status &s) { +Status::CopyFrom(const Status& s) { delete state_; state_ = nullptr; if (s.state_ == nullptr) { @@ -79,11 +75,11 @@ Status::CopyFrom(const Status &s) { memcpy(&length, s.state_ + CODE_WIDTH, sizeof(length)); int buff_len = length + sizeof(length) + CODE_WIDTH; state_ = new char[buff_len]; - memcpy((void *) state_, (void *) s.state_, buff_len); + memcpy(state_, s.state_, buff_len); } void -Status::MoveFrom(Status &s) { +Status::MoveFrom(Status& s) { delete state_; state_ = s.state_; s.state_ = nullptr; @@ -113,19 +109,26 @@ Status::ToString() const { std::string result; switch (code()) { - case DB_SUCCESS:result = "OK "; + case DB_SUCCESS: + result = "OK "; break; - case DB_ERROR:result = "Error: "; + case DB_ERROR: + result = "Error: "; break; - case DB_META_TRANSACTION_FAILED:result = "Database error: "; + case DB_META_TRANSACTION_FAILED: + result = "Database error: "; break; - case DB_NOT_FOUND:result = "Not found: "; + case DB_NOT_FOUND: + result = "Not found: "; break; - case DB_ALREADY_EXIST:result = "Already exist: "; + case DB_ALREADY_EXIST: + result = "Already exist: "; break; - case DB_INVALID_PATH:result = "Invalid path: "; + case DB_INVALID_PATH: + result = "Invalid path: "; break; - default:result = "Error code(" + std::to_string(code()) + "): "; + default: + result = "Error code(" + std::to_string(code()) + "): "; break; } @@ -133,5 +136,4 @@ Status::ToString() const { return result; } -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/Status.h b/cpp/src/utils/Status.h index 8f8f238979763e265efb32c41ad1b9eb233a1892..07a12261bba03b8cea48f8f19ac3e7926b9743a6 100644 --- a/cpp/src/utils/Status.h +++ b/cpp/src/utils/Status.h @@ -15,33 +15,31 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "utils/Error.h" #include -namespace zilliz { namespace milvus { using StatusCode = ErrorCode; class Status { public: - Status(StatusCode code, const std::string &msg); + Status(StatusCode code, const std::string& msg); Status(); ~Status(); - Status(const Status &s); + Status(const Status& s); - Status & - operator=(const Status &s); + Status& + operator=(const Status& s); - Status(Status &&s); + Status(Status&& s); - Status & - operator=(Status &&s); + Status& + operator=(Status&& s); static Status OK() { @@ -55,7 +53,7 @@ class Status { StatusCode code() const { - return (state_ == nullptr) ? 0 : *(StatusCode *) (state_); + return (state_ == nullptr) ? 0 : *(StatusCode*)(state_); } std::string @@ -66,14 +64,13 @@ class Status { private: inline void - CopyFrom(const Status &s); + CopyFrom(const Status& s); inline void - MoveFrom(Status &s); + MoveFrom(Status& s); private: - const char *state_ = nullptr; -}; // Status + char* state_ = nullptr; +}; // Status -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/StringHelpFunctions.cpp b/cpp/src/utils/StringHelpFunctions.cpp index 8c9e888d3ab1bd575210140039d9af1d4a52b331..230cc1a0ff57dcf815aa805c605e3b91f1f42a81 100644 --- a/cpp/src/utils/StringHelpFunctions.cpp +++ b/cpp/src/utils/StringHelpFunctions.cpp @@ -19,12 +19,11 @@ #include -namespace zilliz { namespace milvus { namespace server { void -StringHelpFunctions::TrimStringBlank(std::string &string) { +StringHelpFunctions::TrimStringBlank(std::string& string) { if (!string.empty()) { static std::string s_format(" \n\r\t"); string.erase(0, string.find_first_not_of(s_format)); @@ -33,7 +32,7 @@ StringHelpFunctions::TrimStringBlank(std::string &string) { } void -StringHelpFunctions::TrimStringQuote(std::string &string, const std::string &qoute) { +StringHelpFunctions::TrimStringQuote(std::string& string, const std::string& qoute) { if (!string.empty()) { string.erase(0, string.find_first_not_of(qoute)); string.erase(string.find_last_not_of(qoute) + 1); @@ -41,9 +40,8 @@ StringHelpFunctions::TrimStringQuote(std::string &string, const std::string &qou } Status -StringHelpFunctions::SplitStringByDelimeter(const std::string &str, - const std::string &delimeter, - std::vector &result) { +StringHelpFunctions::SplitStringByDelimeter(const std::string& str, const std::string& delimeter, + std::vector& result) { if (str.empty()) { return Status::OK(); } @@ -64,10 +62,8 @@ StringHelpFunctions::SplitStringByDelimeter(const std::string &str, } Status -StringHelpFunctions::SplitStringByQuote(const std::string &str, - const std::string &delimeter, - const std::string "e, - std::vector &result) { +StringHelpFunctions::SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, + std::vector& result) { if (quote.empty()) { return SplitStringByDelimeter(str, delimeter, result); } @@ -126,6 +122,5 @@ StringHelpFunctions::SplitStringByQuote(const std::string &str, return Status::OK(); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/StringHelpFunctions.h b/cpp/src/utils/StringHelpFunctions.h index ebffa075ad1ea697dba380a19f009e5b951fcea7..cb355332f115a0c4087d8cbb9cf020567160c911 100644 --- a/cpp/src/utils/StringHelpFunctions.h +++ b/cpp/src/utils/StringHelpFunctions.h @@ -19,10 +19,9 @@ #include "utils/Status.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace server { @@ -31,34 +30,33 @@ class StringHelpFunctions { StringHelpFunctions() = default; public: - static void TrimStringBlank(std::string &string); + static void + TrimStringBlank(std::string& string); - static void TrimStringQuote(std::string &string, const std::string &qoute); + static void + TrimStringQuote(std::string& string, const std::string& qoute); - //split string by delimeter ',' + // split string by delimeter ',' // a,b,c a | b | c // a,b, a | b | // ,b,c | b | c // ,b, | b | // ,, | | // a a - static Status SplitStringByDelimeter(const std::string &str, - const std::string &delimeter, - std::vector &result); + static Status + SplitStringByDelimeter(const std::string& str, const std::string& delimeter, std::vector& result); - //assume the table has two columns, quote='\"', delimeter=',' + // assume the table has two columns, quote='\"', delimeter=',' // a,b a | b // "aa,gg,yy",b aa,gg,yy | b // aa"dd,rr"kk,pp aadd,rrkk | pp // "aa,bb" aa,bb // 55,1122\"aa,bb\",yyy,\"kkk\" 55 | 1122aa,bb | yyy | kkk // "55,1122"aa,bb",yyy,"kkk" illegal - static Status SplitStringByQuote(const std::string &str, - const std::string &delimeter, - const std::string "e, - std::vector &result); + static Status + SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, + std::vector& result); }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/ThreadPool.h b/cpp/src/utils/ThreadPool.h index 0524f6206dffee97b888e97efaef3161a1796721..d605d70018e8606abfed9f83b352ba61276c717e 100644 --- a/cpp/src/utils/ThreadPool.h +++ b/cpp/src/utils/ThreadPool.h @@ -17,29 +17,28 @@ #pragma once -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include #include +#include -#define MAX_THREADS_NUM 32 +#define MAX_THREADS_NUM 32 -namespace zilliz { namespace milvus { class ThreadPool { public: explicit ThreadPool(size_t threads, size_t queue_size = 1000); - template - auto enqueue(F &&f, Args &&... args) - -> std::future::type>; + template + auto + enqueue(F&& f, Args&&... args) -> std::future::type>; ~ThreadPool(); @@ -61,37 +60,31 @@ class ThreadPool { }; // the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads, size_t queue_size) - : max_queue_size_(queue_size), stop(false) { +inline ThreadPool::ThreadPool(size_t threads, size_t queue_size) : max_queue_size_(queue_size), stop(false) { for (size_t i = 0; i < threads; ++i) - workers_.emplace_back( - [this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex_); - this->condition_.wait(lock, - [this] { - return this->stop || !this->tasks_.empty(); - }); - if (this->stop && this->tasks_.empty()) - return; - task = std::move(this->tasks_.front()); - this->tasks_.pop(); - } - this->condition_.notify_all(); - - task(); + workers_.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex_); + this->condition_.wait(lock, [this] { return this->stop || !this->tasks_.empty(); }); + if (this->stop && this->tasks_.empty()) + return; + task = std::move(this->tasks_.front()); + this->tasks_.pop(); } - }); + this->condition_.notify_all(); + + task(); + } + }); } // add new work item to the pool -template +template auto -ThreadPool::enqueue(F &&f, Args &&... args) --> std::future::type> { +ThreadPool::enqueue(F&& f, Args&&... args) -> std::future::type> { using return_type = typename std::result_of::type; auto task = std::make_shared >( @@ -100,17 +93,12 @@ ThreadPool::enqueue(F &&f, Args &&... args) std::future res = task->get_future(); { std::unique_lock lock(queue_mutex_); - this->condition_.wait(lock, - [this] { - return this->tasks_.size() < max_queue_size_; - }); + this->condition_.wait(lock, [this] { return this->tasks_.size() < max_queue_size_; }); // don't allow enqueueing after stopping the pool if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); - tasks_.emplace([task]() { - (*task)(); - }); + tasks_.emplace([task]() { (*task)(); }); } condition_.notify_all(); return res; @@ -123,11 +111,9 @@ inline ThreadPool::~ThreadPool() { stop = true; } condition_.notify_all(); - for (std::thread &worker : workers_) { + for (std::thread& worker : workers_) { worker.join(); } } -} // namespace milvus -} // namespace zilliz - +} // namespace milvus diff --git a/cpp/src/utils/TimeRecorder.cpp b/cpp/src/utils/TimeRecorder.cpp index 5246b35f13413039ab052eea03c4884b16f46a62..f3061d9d2b5d28d3919342948575460dbc202a67 100644 --- a/cpp/src/utils/TimeRecorder.cpp +++ b/cpp/src/utils/TimeRecorder.cpp @@ -18,13 +18,9 @@ #include "utils/TimeRecorder.h" #include "utils/Log.h" -namespace zilliz { namespace milvus { -TimeRecorder::TimeRecorder(const std::string &header, - int64_t log_level) : - header_(header), - log_level_(log_level) { +TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) { start_ = last_ = stdclock::now(); } @@ -40,9 +36,10 @@ TimeRecorder::GetTimeSpanStr(double span) { } void -TimeRecorder::PrintTimeRecord(const std::string &msg, double span) { +TimeRecorder::PrintTimeRecord(const std::string& msg, double span) { std::string str_log; - if (!header_.empty()) str_log += header_ + ": "; + if (!header_.empty()) + str_log += header_ + ": "; str_log += msg; str_log += " ("; str_log += TimeRecorder::GetTimeSpanStr(span); @@ -81,7 +78,7 @@ TimeRecorder::PrintTimeRecord(const std::string &msg, double span) { } double -TimeRecorder::RecordSection(const std::string &msg) { +TimeRecorder::RecordSection(const std::string& msg) { stdclock::time_point curr = stdclock::now(); double span = (std::chrono::duration(curr - last_)).count(); last_ = curr; @@ -91,7 +88,7 @@ TimeRecorder::RecordSection(const std::string &msg) { } double -TimeRecorder::ElapseFromBegin(const std::string &msg) { +TimeRecorder::ElapseFromBegin(const std::string& msg) { stdclock::time_point curr = stdclock::now(); double span = (std::chrono::duration(curr - start_)).count(); @@ -99,5 +96,4 @@ TimeRecorder::ElapseFromBegin(const std::string &msg) { return span; } -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/TimeRecorder.h b/cpp/src/utils/TimeRecorder.h index 2bb937e71f09548d43241285849093bb01951be4..cc0a86fbe00a7aa87ee315281942895bac3c516a 100644 --- a/cpp/src/utils/TimeRecorder.h +++ b/cpp/src/utils/TimeRecorder.h @@ -17,29 +17,31 @@ #pragma once -#include #include +#include -namespace zilliz { namespace milvus { class TimeRecorder { using stdclock = std::chrono::high_resolution_clock; public: - TimeRecorder(const std::string &header, - int64_t log_level = 1); + explicit TimeRecorder(const std::string& header, int64_t log_level = 1); - ~TimeRecorder();//trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 + ~TimeRecorder(); // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 - double RecordSection(const std::string &msg); + double + RecordSection(const std::string& msg); - double ElapseFromBegin(const std::string &msg); + double + ElapseFromBegin(const std::string& msg); - static std::string GetTimeSpanStr(double span); + static std::string + GetTimeSpanStr(double span); private: - void PrintTimeRecord(const std::string &msg, double span); + void + PrintTimeRecord(const std::string& msg, double span); private: std::string header_; @@ -48,5 +50,4 @@ class TimeRecorder { int64_t log_level_; }; -} // namespace milvus -} // namespace zilliz +} // namespace milvus diff --git a/cpp/src/utils/ValidationUtil.cpp b/cpp/src/utils/ValidationUtil.cpp index 8d85c64d31057c823e04d832e080fb5c1bf9dea9..b982a31f5e18cfa36bea170ae068715d5476f37b 100644 --- a/cpp/src/utils/ValidationUtil.cpp +++ b/cpp/src/utils/ValidationUtil.cpp @@ -15,28 +15,26 @@ // specific language governing permissions and limitations // under the License. - #include "utils/ValidationUtil.h" -#include "db/engine/ExecutionEngine.h" #include "Log.h" +#include "db/engine/ExecutionEngine.h" -#include -#include #include -#include +#include #include #include +#include +#include -namespace zilliz { namespace milvus { namespace server { constexpr size_t TABLE_NAME_SIZE_LIMIT = 255; constexpr int64_t TABLE_DIMENSION_LIMIT = 16384; -constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; //index trigger size max = 4096 MB +constexpr int32_t INDEX_FILE_SIZE_LIMIT = 4096; // index trigger size max = 4096 MB Status -ValidationUtil::ValidateTableName(const std::string &table_name) { +ValidationUtil::ValidateTableName(const std::string& table_name) { // Table name shouldn't be empty. if (table_name.empty()) { std::string msg = "Empty table name"; @@ -74,7 +72,11 @@ ValidationUtil::ValidateTableName(const std::string &table_name) { Status ValidationUtil::ValidateTableDimension(int64_t dimension) { - if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) { + if (dimension <= 0) { + std::string msg = "Dimension value should be greater than 0"; + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); + } else if (dimension > TABLE_DIMENSION_LIMIT) { std::string msg = "Table dimension excceed the limitation: " + std::to_string(TABLE_DIMENSION_LIMIT); SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); @@ -85,8 +87,8 @@ ValidationUtil::ValidateTableDimension(int64_t dimension) { Status ValidationUtil::ValidateTableIndexType(int32_t index_type) { - int engine_type = (int) engine::EngineType(index_type); - if (engine_type <= 0 || engine_type > (int) engine::EngineType::MAX_VALUE) { + int engine_type = static_cast(engine::EngineType(index_type)); + if (engine_type <= 0 || engine_type > static_cast(engine::EngineType::MAX_VALUE)) { std::string msg = "Invalid index type: " + std::to_string(index_type); SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_TYPE, msg); @@ -98,7 +100,7 @@ ValidationUtil::ValidateTableIndexType(int32_t index_type) { Status ValidationUtil::ValidateTableIndexNlist(int32_t nlist) { if (nlist <= 0) { - std::string msg = "Invalid nlist value: " + std::to_string(nlist); + std::string msg = "nlist value should be greater than 0"; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_NLIST, msg); } @@ -119,7 +121,8 @@ ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) { Status ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { - if (metric_type != (int32_t) engine::MetricType::L2 && metric_type != (int32_t) engine::MetricType::IP) { + if (metric_type != static_cast(engine::MetricType::L2) && + metric_type != static_cast(engine::MetricType::IP)) { std::string msg = "Invalid metric type: " + std::to_string(metric_type); SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_METRIC_TYPE, msg); @@ -128,9 +131,9 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { } Status -ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema &table_schema) { +ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) { if (top_k <= 0 || top_k > 2048) { - std::string msg = "Invalid top k value: " + std::to_string(top_k); + std::string msg = "Invalid top k value: " + std::to_string(top_k) + ", rational range [1, 2048]"; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TOPK, msg); } @@ -139,9 +142,10 @@ ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchem } Status -ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema &table_schema) { +ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) { if (nprobe <= 0 || nprobe > table_schema.nlist_) { - std::string msg = "Invalid nprobe value: " + std::to_string(nprobe); + std::string msg = "Invalid nprobe value: " + std::to_string(nprobe) + ", rational range [1, " + + std::to_string(table_schema.nlist_) + "]"; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_NPROBE, msg); } @@ -153,7 +157,7 @@ Status ValidationUtil::ValidateGpuIndex(uint32_t gpu_index) { int num_devices = 0; auto cuda_err = cudaGetDeviceCount(&num_devices); - if (cuda_err) { + if (cuda_err != cudaSuccess) { std::string msg = "Failed to get gpu card number, cuda error:" + std::to_string(cuda_err); SERVER_LOG_ERROR << msg; return Status(SERVER_UNEXPECTED_ERROR, msg); @@ -169,7 +173,7 @@ ValidationUtil::ValidateGpuIndex(uint32_t gpu_index) { } Status -ValidationUtil::GetGpuMemory(uint32_t gpu_index, size_t &memory) { +ValidationUtil::GetGpuMemory(uint32_t gpu_index, size_t& memory) { cudaDeviceProp deviceProp; auto cuda_err = cudaGetDeviceProperties(&deviceProp, gpu_index); if (cuda_err) { @@ -183,13 +187,14 @@ ValidationUtil::GetGpuMemory(uint32_t gpu_index, size_t &memory) { } Status -ValidationUtil::ValidateIpAddress(const std::string &ip_address) { +ValidationUtil::ValidateIpAddress(const std::string& ip_address) { struct in_addr address; int result = inet_pton(AF_INET, ip_address.c_str(), &address); switch (result) { - case 1:return Status::OK(); + case 1: + return Status::OK(); case 0: { std::string msg = "Invalid IP address: " + ip_address; SERVER_LOG_ERROR << msg; @@ -204,7 +209,7 @@ ValidationUtil::ValidateIpAddress(const std::string &ip_address) { } Status -ValidationUtil::ValidateStringIsNumber(const std::string &str) { +ValidationUtil::ValidateStringIsNumber(const std::string& str) { if (str.empty() || !std::all_of(str.begin(), str.end(), ::isdigit)) { return Status(SERVER_INVALID_ARGUMENT, "Invalid number"); } @@ -217,20 +222,18 @@ ValidationUtil::ValidateStringIsNumber(const std::string &str) { } Status -ValidationUtil::ValidateStringIsBool(const std::string &str) { +ValidationUtil::ValidateStringIsBool(const std::string& str) { std::string s = str; std::transform(s.begin(), s.end(), s.begin(), ::tolower); - if (s == "true" || s == "on" || s == "yes" || s == "1" || - s == "false" || s == "off" || s == "no" || s == "0" || + if (s == "true" || s == "on" || s == "yes" || s == "1" || s == "false" || s == "off" || s == "no" || s == "0" || s.empty()) { return Status::OK(); - } else { - return Status(SERVER_INVALID_ARGUMENT, "Invalid boolean: " + str); } + return Status(SERVER_INVALID_ARGUMENT, "Invalid boolean: " + str); } Status -ValidationUtil::ValidateStringIsFloat(const std::string &str) { +ValidationUtil::ValidateStringIsFloat(const std::string& str) { try { float val = std::stof(str); } catch (...) { @@ -240,19 +243,15 @@ ValidationUtil::ValidateStringIsFloat(const std::string &str) { } Status -ValidationUtil::ValidateDbURI(const std::string &uri) { +ValidationUtil::ValidateDbURI(const std::string& uri) { std::string dialectRegex = "(.*)"; std::string usernameRegex = "(.*)"; std::string passwordRegex = "(.*)"; std::string hostRegex = "(.*)"; std::string portRegex = "(.*)"; std::string dbNameRegex = "(.*)"; - std::string uriRegexStr = dialectRegex + "\\:\\/\\/" + - usernameRegex + "\\:" + - passwordRegex + "\\@" + - hostRegex + "\\:" + - portRegex + "\\/" + - dbNameRegex; + std::string uriRegexStr = dialectRegex + "\\:\\/\\/" + usernameRegex + "\\:" + passwordRegex + "\\@" + hostRegex + + "\\:" + portRegex + "\\/" + dbNameRegex; std::regex uriRegex(uriRegexStr); std::smatch pieces_match; @@ -266,17 +265,17 @@ ValidationUtil::ValidateDbURI(const std::string &uri) { okay = false; } -/* - * Could be DNS, skip checking - * - std::string host = pieces_match[4].str(); - if (!host.empty() && host != "localhost") { - if (ValidateIpAddress(host) != SERVER_SUCCESS) { - SERVER_LOG_ERROR << "Invalid host ip address in uri = " << host; - okay = false; - } - } -*/ + /* + * Could be DNS, skip checking + * + std::string host = pieces_match[4].str(); + if (!host.empty() && host != "localhost") { + if (ValidateIpAddress(host) != SERVER_SUCCESS) { + SERVER_LOG_ERROR << "Invalid host ip address in uri = " << host; + okay = false; + } + } + */ std::string port = pieces_match[5].str(); if (!port.empty()) { @@ -294,6 +293,5 @@ ValidationUtil::ValidateDbURI(const std::string &uri) { return (okay ? Status::OK() : Status(SERVER_INVALID_ARGUMENT, "Invalid db backend uri")); } -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/utils/ValidationUtil.h b/cpp/src/utils/ValidationUtil.h index 44d6065a649ddc347ecab7ef02f5cdfb80c698be..7b24c93fb501f95c2ab9e4360ea271c6768ae15f 100644 --- a/cpp/src/utils/ValidationUtil.h +++ b/cpp/src/utils/ValidationUtil.h @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "db/meta/MetaTypes.h" @@ -23,7 +22,6 @@ #include -namespace zilliz { namespace milvus { namespace server { @@ -33,7 +31,7 @@ class ValidationUtil { public: static Status - ValidateTableName(const std::string &table_name); + ValidateTableName(const std::string& table_name); static Status ValidateTableDimension(int64_t dimension); @@ -51,33 +49,32 @@ class ValidationUtil { ValidateTableIndexMetricType(int32_t metric_type); static Status - ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema &table_schema); + ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema); static Status - ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema &table_schema); + ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema); static Status ValidateGpuIndex(uint32_t gpu_index); static Status - GetGpuMemory(uint32_t gpu_index, size_t &memory); + GetGpuMemory(uint32_t gpu_index, size_t& memory); static Status - ValidateIpAddress(const std::string &ip_address); + ValidateIpAddress(const std::string& ip_address); static Status - ValidateStringIsNumber(const std::string &str); + ValidateStringIsNumber(const std::string& str); static Status - ValidateStringIsBool(const std::string &str); + ValidateStringIsBool(const std::string& str); static Status - ValidateStringIsFloat(const std::string &str); + ValidateStringIsFloat(const std::string& str); static Status - ValidateDbURI(const std::string &uri); + ValidateDbURI(const std::string& uri); }; -} // namespace server -} // namespace milvus -} // namespace zilliz +} // namespace server +} // namespace milvus diff --git a/cpp/src/wrapper/ConfAdapter.cpp b/cpp/src/wrapper/ConfAdapter.cpp index ea6be1acf3c0831a36ccf43d6d41363b2df7a5db..2dcf6bab7e1a0c9ccb354a5bcd46ebce9dffca27 100644 --- a/cpp/src/wrapper/ConfAdapter.cpp +++ b/cpp/src/wrapper/ConfAdapter.cpp @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. +#include "wrapper/ConfAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "utils/Log.h" #include -#include "ConfAdapter.h" -#include "src/utils/Log.h" -#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include -// TODO: add conf checker +// TODO(lxj): add conf checker -namespace zilliz { namespace milvus { namespace engine { @@ -42,7 +42,7 @@ ConfAdapter::MatchBase(knowhere::Config conf) { } knowhere::Config -ConfAdapter::Match(const TempMetaConf &metaconf) { +ConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); conf->d = metaconf.dim; conf->metric_type = metaconf.metric_type; @@ -52,14 +52,14 @@ ConfAdapter::Match(const TempMetaConf &metaconf) { } knowhere::Config -ConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { +ConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) { auto conf = std::make_shared(); conf->k = metaconf.k; return conf; } knowhere::Config -IVFConfAdapter::Match(const TempMetaConf &metaconf) { +IVFConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->d = metaconf.dim; @@ -72,7 +72,7 @@ IVFConfAdapter::Match(const TempMetaConf &metaconf) { static constexpr float TYPICAL_COUNT = 1000000.0; int64_t -IVFConfAdapter::MatchNlist(const int64_t &size, const int64_t &nlist) { +IVFConfAdapter::MatchNlist(const int64_t& size, const int64_t& nlist) { if (size <= TYPICAL_COUNT / 16384 + 1) { // handle less row count, avoid nlist set to 0 return 1; @@ -80,11 +80,11 @@ IVFConfAdapter::MatchNlist(const int64_t &size, const int64_t &nlist) { // calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT return int(size / TYPICAL_COUNT * 16384); } - return 0; + return nlist; } knowhere::Config -IVFConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { +IVFConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) { auto conf = std::make_shared(); conf->k = metaconf.k; conf->nprobe = metaconf.nprobe; @@ -95,17 +95,16 @@ IVFConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) case IndexType::FAISS_IVFPQ_GPU: if (conf->nprobe > GPU_MAX_NRPOBE) { WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE - << ", but you passed " << conf->nprobe - << ". Search with " << GPU_MAX_NRPOBE << " instead"; + << ", but you passed " << conf->nprobe << ". Search with " << GPU_MAX_NRPOBE + << " instead"; conf->nprobe = GPU_MAX_NRPOBE; - } } return conf; } knowhere::Config -IVFSQConfAdapter::Match(const TempMetaConf &metaconf) { +IVFSQConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->d = metaconf.dim; @@ -117,7 +116,7 @@ IVFSQConfAdapter::Match(const TempMetaConf &metaconf) { } knowhere::Config -IVFPQConfAdapter::Match(const TempMetaConf &metaconf) { +IVFPQConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->d = metaconf.dim; @@ -130,7 +129,7 @@ IVFPQConfAdapter::Match(const TempMetaConf &metaconf) { } knowhere::Config -NSGConfAdapter::Match(const TempMetaConf &metaconf) { +NSGConfAdapter::Match(const TempMetaConf& metaconf) { auto conf = std::make_shared(); conf->nlist = MatchNlist(metaconf.size, metaconf.nlist); conf->d = metaconf.dim; @@ -145,17 +144,20 @@ NSGConfAdapter::Match(const TempMetaConf &metaconf) { conf->out_degree = 50 + 5 * scale_factor; conf->candidate_pool_size = 200 + 100 * scale_factor; MatchBase(conf); + + // WRAPPER_LOG_DEBUG << "nlist: " << conf->nlist + // << ", gpu_id: " << conf->gpu_id << ", d: " << conf->d + // << ", nprobe: " << conf->nprobe << ", knng: " << conf->knng; return conf; } knowhere::Config -NSGConfAdapter::MatchSearch(const TempMetaConf &metaconf, const IndexType &type) { +NSGConfAdapter::MatchSearch(const TempMetaConf& metaconf, const IndexType& type) { auto conf = std::make_shared(); conf->k = metaconf.k; conf->search_length = metaconf.search_length; return conf; } -} -} -} +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/ConfAdapter.h b/cpp/src/wrapper/ConfAdapter.h index d0f032cb61a080400915ec0db7023f8c7241ee50..4c8e528a2d63c8e9898ef0c1c906732e30e21e2f 100644 --- a/cpp/src/wrapper/ConfAdapter.h +++ b/cpp/src/wrapper/ConfAdapter.h @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "knowhere/common/Config.h" #include "VecIndex.h" +#include "knowhere/common/Config.h" +#include -namespace zilliz { namespace milvus { namespace engine { @@ -42,13 +41,17 @@ struct TempMetaConf { class ConfAdapter { public: virtual knowhere::Config - Match(const TempMetaConf &metaconf); + Match(const TempMetaConf& metaconf); virtual knowhere::Config - MatchSearch(const TempMetaConf &metaconf, const IndexType &type); + MatchSearch(const TempMetaConf& metaconf, const IndexType& type); + + // virtual void + // Dump(){} protected: - static void MatchBase(knowhere::Config conf); + static void + MatchBase(knowhere::Config conf); }; using ConfAdapterPtr = std::shared_ptr; @@ -56,36 +59,36 @@ using ConfAdapterPtr = std::shared_ptr; class IVFConfAdapter : public ConfAdapter { public: knowhere::Config - Match(const TempMetaConf &metaconf) override; + Match(const TempMetaConf& metaconf) override; knowhere::Config - MatchSearch(const TempMetaConf &metaconf, const IndexType &type) override; + MatchSearch(const TempMetaConf& metaconf, const IndexType& type) override; protected: - static int64_t MatchNlist(const int64_t &size, const int64_t &nlist); + static int64_t + MatchNlist(const int64_t& size, const int64_t& nlist); }; class IVFSQConfAdapter : public IVFConfAdapter { public: knowhere::Config - Match(const TempMetaConf &metaconf) override; + Match(const TempMetaConf& metaconf) override; }; class IVFPQConfAdapter : public IVFConfAdapter { public: knowhere::Config - Match(const TempMetaConf &metaconf) override; + Match(const TempMetaConf& metaconf) override; }; class NSGConfAdapter : public IVFConfAdapter { public: knowhere::Config - Match(const TempMetaConf &metaconf) override; + Match(const TempMetaConf& metaconf) override; knowhere::Config - MatchSearch(const TempMetaConf &metaconf, const IndexType &type) final; + MatchSearch(const TempMetaConf& metaconf, const IndexType& type) final; }; -} -} -} +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/ConfAdapterMgr.cpp b/cpp/src/wrapper/ConfAdapterMgr.cpp index cb7331a99d2b509fd25dbec88e8484c80490f3f3..05c23c4238dbfa7a2a1c8a99dc883aa28d5c97f1 100644 --- a/cpp/src/wrapper/ConfAdapterMgr.cpp +++ b/cpp/src/wrapper/ConfAdapterMgr.cpp @@ -15,18 +15,16 @@ // specific language governing permissions and limitations // under the License. +#include "wrapper/ConfAdapterMgr.h" +#include "utils/Exception.h" -#include "src/utils/Exception.h" -#include "ConfAdapterMgr.h" - - -namespace zilliz { namespace milvus { namespace engine { ConfAdapterPtr -AdapterMgr::GetAdapter(const IndexType &indexType) { - if (!init_) RegisterAdapter(); +AdapterMgr::GetAdapter(const IndexType& indexType) { + if (!init_) + RegisterAdapter(); auto it = table_.find(indexType); if (it != table_.end()) { @@ -36,8 +34,8 @@ AdapterMgr::GetAdapter(const IndexType &indexType) { } } +#define REGISTER_CONF_ADAPTER(T, KEY, NAME) static AdapterMgr::register_t reg_##NAME##_(KEY) -#define REGISTER_CONF_ADAPTER(T, KEY, NAME) static AdapterMgr::register_treg_##NAME##_(KEY) void AdapterMgr::RegisterAdapter() { init_ = true; @@ -58,7 +56,5 @@ AdapterMgr::RegisterAdapter() { REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexType::NSG_MIX, nsg_mix); } -} // engine -} // milvus -} // zilliz - +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/ConfAdapterMgr.h b/cpp/src/wrapper/ConfAdapterMgr.h index 3e68a3df26a6a1389a005142aed39ca01ca600a4..8d5fa22877a1d79fe33daec9dccad641503da423 100644 --- a/cpp/src/wrapper/ConfAdapterMgr.h +++ b/cpp/src/wrapper/ConfAdapterMgr.h @@ -15,37 +15,35 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "VecIndex.h" #include "ConfAdapter.h" +#include "VecIndex.h" #include +#include +#include -namespace zilliz { namespace milvus { namespace engine { class AdapterMgr { public: - template + template struct register_t { - explicit register_t(const IndexType &key) { - AdapterMgr::GetInstance().table_.emplace(key, [] { - return std::make_shared(); - }); + explicit register_t(const IndexType& key) { + AdapterMgr::GetInstance().table_.emplace(key, [] { return std::make_shared(); }); } }; - static AdapterMgr & + static AdapterMgr& GetInstance() { static AdapterMgr instance; return instance; } ConfAdapterPtr - GetAdapter(const IndexType &indexType); + GetAdapter(const IndexType& indexType); void RegisterAdapter(); @@ -55,10 +53,5 @@ class AdapterMgr { std::map > table_; }; - -} // engine -} // milvus -} // zilliz - - - +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/DataTransfer.cpp b/cpp/src/wrapper/DataTransfer.cpp index c6603b1991d4a39d02739405702b02c610e73d19..5eb83290d18b259f27506313984eed8833ce4e66 100644 --- a/cpp/src/wrapper/DataTransfer.cpp +++ b/cpp/src/wrapper/DataTransfer.cpp @@ -15,39 +15,37 @@ // specific language governing permissions and limitations // under the License. - #include "wrapper/DataTransfer.h" -#include #include #include +#include -namespace zilliz { namespace milvus { namespace engine { knowhere::DatasetPtr -GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const int64_t *ids) { +GenDatasetWithIds(const int64_t& nb, const int64_t& dim, const float* xb, const int64_t* ids) { std::vector shape{nb, dim}; - auto tensor = knowhere::ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); + auto tensor = knowhere::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape); std::vector tensors{tensor}; std::vector tensor_fields{knowhere::ConstructFloatField("data")}; auto tensor_schema = std::make_shared(tensor_fields); - auto id_array = knowhere::ConstructInt64Array((uint8_t *) ids, nb * sizeof(int64_t)); + auto id_array = knowhere::ConstructInt64Array((uint8_t*)ids, nb * sizeof(int64_t)); std::vector arrays{id_array}; std::vector array_fields{knowhere::ConstructInt64Field("id")}; auto array_schema = std::make_shared(tensor_fields); - auto dataset = std::make_shared(std::move(arrays), array_schema, - std::move(tensors), tensor_schema); + auto dataset = + std::make_shared(std::move(arrays), array_schema, std::move(tensors), tensor_schema); return dataset; } knowhere::DatasetPtr -GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) { +GenDataset(const int64_t& nb, const int64_t& dim, const float* xb) { std::vector shape{nb, dim}; - auto tensor = knowhere::ConstructFloatTensor((uint8_t *) xb, nb * dim * sizeof(float), shape); + auto tensor = knowhere::ConstructFloatTensor((uint8_t*)xb, nb * dim * sizeof(float), shape); std::vector tensors{tensor}; std::vector tensor_fields{knowhere::ConstructFloatField("data")}; auto tensor_schema = std::make_shared(tensor_fields); @@ -56,6 +54,5 @@ GenDataset(const int64_t &nb, const int64_t &dim, const float *xb) { return dataset; } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/DataTransfer.h b/cpp/src/wrapper/DataTransfer.h index 070adca0555488dd22bf2ad075a33ee821611fc2..e945eaa6dbb0f275e19d6edcbc36d634a6b62f5a 100644 --- a/cpp/src/wrapper/DataTransfer.h +++ b/cpp/src/wrapper/DataTransfer.h @@ -15,21 +15,18 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "knowhere/adapter/Structure.h" -namespace zilliz { namespace milvus { namespace engine { -extern zilliz::knowhere::DatasetPtr -GenDatasetWithIds(const int64_t &nb, const int64_t &dim, const float *xb, const int64_t *ids); +extern knowhere::DatasetPtr +GenDatasetWithIds(const int64_t& nb, const int64_t& dim, const float* xb, const int64_t* ids); -extern zilliz::knowhere::DatasetPtr -GenDataset(const int64_t &nb, const int64_t &dim, const float *xb); +extern knowhere::DatasetPtr +GenDataset(const int64_t& nb, const int64_t& dim, const float* xb); -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/KnowhereResource.cpp b/cpp/src/wrapper/KnowhereResource.cpp index b46acb1a4acb7a9610bc5cff08540d4b57c45378..d291bb92998bd3d20b1c04351b71af8d2fe7f20e 100644 --- a/cpp/src/wrapper/KnowhereResource.cpp +++ b/cpp/src/wrapper/KnowhereResource.cpp @@ -15,18 +15,16 @@ // specific language governing permissions and limitations // under the License. - #include "wrapper/KnowhereResource.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #include "server/Config.h" #include #include -#include #include #include +#include -namespace zilliz { namespace milvus { namespace engine { @@ -43,22 +41,24 @@ KnowhereResource::Initialize() { GpuResourcesArray gpu_resources; Status s; - //get build index gpu resource - server::Config &config = server::Config::GetInstance(); + // get build index gpu resource + server::Config& config = server::Config::GetInstance(); int32_t build_index_gpu; s = config.GetDBConfigBuildIndexGPU(build_index_gpu); - if (!s.ok()) return s; + if (!s.ok()) + return s; gpu_resources.insert(std::make_pair(build_index_gpu, GpuResourceSetting())); - //get search gpu resource + // get search gpu resource std::vector pool; s = config.GetResourceConfigPool(pool); - if (!s.ok()) return s; + if (!s.ok()) + return s; std::set gpu_ids; - for (auto &resource : pool) { + for (auto& resource : pool) { if (resource.length() < 4 || resource.substr(0, 3) != "gpu") { // invalid continue; @@ -67,12 +67,10 @@ KnowhereResource::Initialize() { gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting())); } - //init gpu resources + // init gpu resources for (auto iter = gpu_resources.begin(); iter != gpu_resources.end(); ++iter) { - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(iter->first, - iter->second.pinned_memory, - iter->second.temp_memory, - iter->second.resource_num); + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(iter->first, iter->second.pinned_memory, + iter->second.temp_memory, iter->second.resource_num); } return Status::OK(); @@ -80,10 +78,9 @@ KnowhereResource::Initialize() { Status KnowhereResource::Finalize() { - knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource. + knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource. return Status::OK(); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/KnowhereResource.h b/cpp/src/wrapper/KnowhereResource.h index a8726f9542202eab6de0f6fba9ebaac1c6a200a7..dff8b32c0b90c8e492afcbbbd50b8ac03cc3747d 100644 --- a/cpp/src/wrapper/KnowhereResource.h +++ b/cpp/src/wrapper/KnowhereResource.h @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. - #pragma once #include "utils/Status.h" -namespace zilliz { namespace milvus { namespace engine { @@ -33,6 +31,5 @@ class KnowhereResource { Finalize(); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/VecImpl.cpp b/cpp/src/wrapper/VecImpl.cpp index 17b9ec61d361e9591f54f1ed04ddf301772c2518..1ed20c8029b76408078fe76b57c33d9f22530e80 100644 --- a/cpp/src/wrapper/VecImpl.cpp +++ b/cpp/src/wrapper/VecImpl.cpp @@ -15,32 +15,26 @@ // specific language governing permissions and limitations // under the License. - #include "wrapper/VecImpl.h" -#include "utils/Log.h" +#include "DataTransfer.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIVFSQHybrid.h" -#include "knowhere/index/vector_index/IndexGPUIVF.h" -#include "knowhere/common/Exception.h" #include "knowhere/index/vector_index/helpers/Cloner.h" -#include "DataTransfer.h" +#include "utils/Log.h" /* * no parameter check in this layer. * only responible for index combination */ -namespace zilliz { namespace milvus { namespace engine { Status -VecIndexImpl::BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) { +VecIndexImpl::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) { try { dim = cfg->d; auto dataset = GenDatasetWithIds(nb, dim, xb, ids); @@ -50,10 +44,10 @@ VecIndexImpl::BuildAll(const int64_t &nb, auto model = index_->Train(dataset, cfg); index_->set_index_model(model); index_->Add(dataset, cfg); - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } @@ -61,15 +55,15 @@ VecIndexImpl::BuildAll(const int64_t &nb, } Status -VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const Config &cfg) { +VecIndexImpl::Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg) { try { auto dataset = GenDatasetWithIds(nb, dim, xb, ids); index_->Add(dataset, cfg); - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } @@ -77,7 +71,7 @@ VecIndexImpl::Add(const int64_t &nb, const float *xb, const int64_t *ids, const } Status -VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *ids, const Config &cfg) { +VecIndexImpl::Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) { try { auto k = cfg->k; auto dataset = GenDataset(nq, dim, xq); @@ -111,24 +105,24 @@ VecIndexImpl::Search(const int64_t &nq, const float *xq, float *dist, int64_t *i // TODO(linxj): avoid copy here. memcpy(ids, p_ids, sizeof(int64_t) * nq * k); memcpy(dist, p_dist, sizeof(float) * nq * k); - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } return Status::OK(); } -zilliz::knowhere::BinarySet +knowhere::BinarySet VecIndexImpl::Serialize() { type = ConvertToCpuIndexType(type); return index_->Serialize(); } Status -VecIndexImpl::Load(const zilliz::knowhere::BinarySet &index_binary) { +VecIndexImpl::Load(const knowhere::BinarySet& index_binary) { index_->Load(index_binary); dim = Dimension(); return Status::OK(); @@ -150,18 +144,18 @@ VecIndexImpl::GetType() { } VecIndexPtr -VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { +VecIndexImpl::CopyToGpu(const int64_t& device_id, const Config& cfg) { // TODO(linxj): exception handle - auto gpu_index = zilliz::knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg); + auto gpu_index = knowhere::cloner::CopyCpuToGpu(index_, device_id, cfg); auto new_index = std::make_shared(gpu_index, ConvertToGpuIndexType(type)); new_index->dim = dim; return new_index; } VecIndexPtr -VecIndexImpl::CopyToCpu(const Config &cfg) { +VecIndexImpl::CopyToCpu(const Config& cfg) { // TODO(linxj): exception handle - auto cpu_index = zilliz::knowhere::cloner::CopyGpuToCpu(index_, cfg); + auto cpu_index = knowhere::cloner::CopyGpuToCpu(index_, cfg); auto new_index = std::make_shared(cpu_index, ConvertToCpuIndexType(type)); new_index->dim = dim; return new_index; @@ -181,30 +175,32 @@ VecIndexImpl::GetDeviceId() { return device_idx->GetGpuDevice(); } // else - return -1; // -1 == cpu + return -1; // -1 == cpu } -float * +float* BFIndex::GetRawVectors() { auto raw_index = std::dynamic_pointer_cast(index_); - if (raw_index) { return raw_index->GetRawVectors(); } + if (raw_index) { + return raw_index->GetRawVectors(); + } return nullptr; } -int64_t * +int64_t* BFIndex::GetRawIds() { return std::static_pointer_cast(index_)->GetRawIds(); } ErrorCode -BFIndex::Build(const Config &cfg) { +BFIndex::Build(const Config& cfg) { try { dim = cfg->d; std::static_pointer_cast(index_)->Train(cfg); - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return KNOWHERE_UNEXPECTED_ERROR; - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return KNOWHERE_ERROR; } @@ -212,22 +208,18 @@ BFIndex::Build(const Config &cfg) { } Status -BFIndex::BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) { +BFIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) { try { dim = cfg->d; auto dataset = GenDatasetWithIds(nb, dim, xb, ids); std::static_pointer_cast(index_)->Train(cfg); index_->Add(dataset, cfg); - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } @@ -236,12 +228,8 @@ BFIndex::BuildAll(const int64_t &nb, // TODO(linxj): add lock here. Status -IVFMixIndex::BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) { +IVFMixIndex::BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) { try { dim = cfg->d; auto dataset = GenDatasetWithIds(nb, dim, xb, ids); @@ -260,10 +248,10 @@ IVFMixIndex::BuildAll(const int64_t &nb, WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; return Status(KNOWHERE_ERROR, "Build IVFMIXIndex Failed"); } - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } @@ -271,7 +259,7 @@ IVFMixIndex::BuildAll(const int64_t &nb, } Status -IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { +IVFMixIndex::Load(const knowhere::BinarySet& index_binary) { index_->Load(index_binary); dim = Dimension(); return Status::OK(); @@ -280,7 +268,7 @@ IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { knowhere::QuantizerPtr IVFHybridIndex::LoadQuantizer(const Config& conf) { // TODO(linxj): Hardcode here - if (auto new_idx = std::dynamic_pointer_cast(index_)){ + if (auto new_idx = std::dynamic_pointer_cast(index_)) { return new_idx->LoadQuantizer(conf); } else { WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); @@ -297,10 +285,10 @@ IVFHybridIndex::SetQuantizer(const knowhere::QuantizerPtr& q) { WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); return Status(KNOWHERE_ERROR, "not support"); } - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } @@ -317,17 +305,18 @@ IVFHybridIndex::UnsetQuantizer() { WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); return Status(KNOWHERE_ERROR, "not support"); } - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } return Status::OK(); } -Status IVFHybridIndex::LoadData(const knowhere::QuantizerPtr &q, const Config &conf) { +Status +IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { try { // TODO(linxj): Hardcode here if (auto new_idx = std::dynamic_pointer_cast(index_)) { @@ -336,16 +325,15 @@ Status IVFHybridIndex::LoadData(const knowhere::QuantizerPtr &q, const Config &c WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type); return Status(KNOWHERE_ERROR, "not support"); } - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_ERROR, e.what()); } return Status::OK(); } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/VecImpl.h b/cpp/src/wrapper/VecImpl.h index c5e2484daac09c1bb9eb671f4b22e45c2bc91951..f35a6ac4cd520320ffb461f8d81f53756d8f2f9d 100644 --- a/cpp/src/wrapper/VecImpl.h +++ b/cpp/src/wrapper/VecImpl.h @@ -15,38 +15,32 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include "knowhere/index/vector_index/VectorIndex.h" #include "VecIndex.h" +#include "knowhere/index/vector_index/VectorIndex.h" -#include #include +#include -namespace zilliz { namespace milvus { namespace engine { class VecIndexImpl : public VecIndex { public: - explicit VecIndexImpl(std::shared_ptr index, const IndexType &type) + explicit VecIndexImpl(std::shared_ptr index, const IndexType& type) : index_(std::move(index)), type(type) { } Status - BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) override; + BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) override; VecIndexPtr - CopyToGpu(const int64_t &device_id, const Config &cfg) override; + CopyToGpu(const int64_t& device_id, const Config& cfg) override; VecIndexPtr - CopyToCpu(const Config &cfg) override; + CopyToCpu(const Config& cfg) override; IndexType GetType() override; @@ -58,13 +52,13 @@ class VecIndexImpl : public VecIndex { Count() override; Status - Add(const int64_t &nb, const float *xb, const int64_t *ids, const Config &cfg) override; + Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg) override; - zilliz::knowhere::BinarySet + knowhere::BinarySet Serialize() override; Status - Load(const zilliz::knowhere::BinarySet &index_binary) override; + Load(const knowhere::BinarySet& index_binary) override; VecIndexPtr Clone() override; @@ -73,32 +67,28 @@ class VecIndexImpl : public VecIndex { GetDeviceId() override; Status - Search(const int64_t &nq, const float *xq, float *dist, int64_t *ids, const Config &cfg) override; + Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg) override; protected: int64_t dim = 0; IndexType type = IndexType::INVALID; - std::shared_ptr index_ = nullptr; + std::shared_ptr index_ = nullptr; }; class IVFMixIndex : public VecIndexImpl { public: - explicit IVFMixIndex(std::shared_ptr index, const IndexType &type) + explicit IVFMixIndex(std::shared_ptr index, const IndexType& type) : VecIndexImpl(std::move(index), type) { } Status - BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) override; + BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) override; Status - Load(const zilliz::knowhere::BinarySet &index_binary) override; + Load(const knowhere::BinarySet& index_binary) override; }; class IVFHybridIndex : public IVFMixIndex { @@ -113,33 +103,28 @@ class IVFHybridIndex : public IVFMixIndex { UnsetQuantizer() override; Status - LoadData(const knowhere::QuantizerPtr &q, const Config &conf) override; + LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override; }; class BFIndex : public VecIndexImpl { public: - explicit BFIndex(std::shared_ptr index) : VecIndexImpl(std::move(index), - IndexType::FAISS_IDMAP) { + explicit BFIndex(std::shared_ptr index) + : VecIndexImpl(std::move(index), IndexType::FAISS_IDMAP) { } ErrorCode - Build(const Config &cfg); + Build(const Config& cfg); - float * + float* GetRawVectors(); Status - BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt, - const float *xt) override; - - int64_t * + BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt, + const float* xt) override; + + int64_t* GetRawIds(); }; -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/VecIndex.cpp b/cpp/src/wrapper/VecIndex.cpp index 27bba6149e016c5bf6550d2efbde841c93d6b881..4d9449cb462b987b174319ff268990ae79a6e59d 100644 --- a/cpp/src/wrapper/VecIndex.cpp +++ b/cpp/src/wrapper/VecIndex.cpp @@ -16,23 +16,22 @@ // under the License. #include "wrapper/VecIndex.h" -#include "knowhere/index/vector_index/IndexIVF.h" +#include "VecImpl.h" +#include "knowhere/common/Exception.h" #include "knowhere/index/vector_index/IndexGPUIVF.h" -#include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" -#include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" #include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" #include "knowhere/index/vector_index/IndexKDT.h" #include "knowhere/index/vector_index/IndexNSG.h" -#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" -#include "knowhere/common/Exception.h" -#include "VecImpl.h" #include "utils/Log.h" #include -namespace zilliz { namespace milvus { namespace engine { @@ -40,18 +39,18 @@ struct FileIOReader { std::fstream fs; std::string name; - explicit FileIOReader(const std::string &fname); + explicit FileIOReader(const std::string& fname); ~FileIOReader(); size_t - operator()(void *ptr, size_t size); + operator()(void* ptr, size_t size); size_t - operator()(void *ptr, size_t size, size_t pos); + operator()(void* ptr, size_t size, size_t pos); }; -FileIOReader::FileIOReader(const std::string &fname) { +FileIOReader::FileIOReader(const std::string& fname) { name = fname; fs = std::fstream(name, std::ios::in | std::ios::binary); } @@ -61,12 +60,12 @@ FileIOReader::~FileIOReader() { } size_t -FileIOReader::operator()(void *ptr, size_t size) { - fs.read(reinterpret_cast(ptr), size); +FileIOReader::operator()(void* ptr, size_t size) { + fs.read(reinterpret_cast(ptr), size); } size_t -FileIOReader::operator()(void *ptr, size_t size, size_t pos) { +FileIOReader::operator()(void* ptr, size_t size, size_t pos) { return 0; } @@ -74,12 +73,13 @@ struct FileIOWriter { std::fstream fs; std::string name; - explicit FileIOWriter(const std::string &fname); + explicit FileIOWriter(const std::string& fname); ~FileIOWriter(); - size_t operator()(void *ptr, size_t size); + size_t + operator()(void* ptr, size_t size); }; -FileIOWriter::FileIOWriter(const std::string &fname) { +FileIOWriter::FileIOWriter(const std::string& fname) { name = fname; fs = std::fstream(name, std::ios::out | std::ios::binary); } @@ -89,79 +89,77 @@ FileIOWriter::~FileIOWriter() { } size_t -FileIOWriter::operator()(void *ptr, size_t size) { - fs.write(reinterpret_cast(ptr), size); +FileIOWriter::operator()(void* ptr, size_t size) { + fs.write(reinterpret_cast(ptr), size); } VecIndexPtr -GetVecIndexFactory(const IndexType &type, const Config &cfg) { - std::shared_ptr index; - auto gpu_device = -1; // TODO(linxj): remove hardcode here +GetVecIndexFactory(const IndexType& type, const Config& cfg) { + std::shared_ptr index; + auto gpu_device = -1; // TODO(linxj): remove hardcode here switch (type) { case IndexType::FAISS_IDMAP: { - index = std::make_shared(); + index = std::make_shared(); return std::make_shared(index); } case IndexType::FAISS_IVFFLAT_CPU: { - index = std::make_shared(); + index = std::make_shared(); break; } case IndexType::FAISS_IVFFLAT_GPU: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); break; } case IndexType::FAISS_IVFFLAT_MIX: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); return std::make_shared(index, IndexType::FAISS_IVFFLAT_MIX); } case IndexType::FAISS_IVFPQ_CPU: { - index = std::make_shared(); + index = std::make_shared(); break; } case IndexType::FAISS_IVFPQ_GPU: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); break; } case IndexType::SPTAG_KDT_RNT_CPU: { - index = std::make_shared(); + index = std::make_shared(); break; } case IndexType::FAISS_IVFSQ8_MIX: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); return std::make_shared(index, IndexType::FAISS_IVFSQ8_MIX); } case IndexType::FAISS_IVFSQ8_CPU: { - index = std::make_shared(); + index = std::make_shared(); break; } case IndexType::FAISS_IVFSQ8_GPU: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); break; } case IndexType::FAISS_IVFSQ8_HYBRID: { - index = std::make_shared(gpu_device); + index = std::make_shared(gpu_device); break; } - case IndexType::NSG_MIX: { // TODO(linxj): bug. - index = std::make_shared(gpu_device); + case IndexType::NSG_MIX: { // TODO(linxj): bug. + index = std::make_shared(gpu_device); break; } - default: { - return nullptr; - } + default: { return nullptr; } } return std::make_shared(index, type); } VecIndexPtr -LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) { +LoadVecIndex(const IndexType& index_type, const knowhere::BinarySet& index_binary) { auto index = GetVecIndexFactory(index_type); index->Load(index_binary); return index; } VecIndexPtr -read_index(const std::string &location) { +read_index(const std::string& location) { knowhere::BinarySet load_data_list; FileIOReader reader(location); reader.fs.seekg(0, reader.fs.end); @@ -206,28 +204,28 @@ read_index(const std::string &location) { } Status -write_index(VecIndexPtr index, const std::string &location) { +write_index(VecIndexPtr index, const std::string& location) { try { auto binaryset = index->Serialize(); auto index_type = index->GetType(); FileIOWriter writer(location); writer(&index_type, sizeof(IndexType)); - for (auto &iter : binaryset.binary_map_) { + for (auto& iter : binaryset.binary_map_) { auto meta = iter.first.c_str(); size_t meta_length = iter.first.length(); writer(&meta_length, sizeof(meta_length)); - writer((void *) meta, meta_length); + writer((void*)meta, meta_length); auto binary = iter.second; int64_t binary_length = binary->size; writer(&binary_length, sizeof(binary_length)); - writer((void *) binary->data.get(), binary_length); + writer((void*)binary->data.get(), binary_length); } - } catch (knowhere::KnowhereException &e) { + } catch (knowhere::KnowhereException& e) { WRAPPER_LOG_ERROR << e.what(); return Status(KNOWHERE_UNEXPECTED_ERROR, e.what()); - } catch (std::exception &e) { + } catch (std::exception& e) { WRAPPER_LOG_ERROR << e.what(); std::string estring(e.what()); if (estring.find("No space left on device") != estring.npos) { @@ -241,7 +239,7 @@ write_index(VecIndexPtr index, const std::string &location) { } IndexType -ConvertToCpuIndexType(const IndexType &type) { +ConvertToCpuIndexType(const IndexType& type) { // TODO(linxj): add IDMAP switch (type) { case IndexType::FAISS_IVFFLAT_GPU: @@ -252,14 +250,12 @@ ConvertToCpuIndexType(const IndexType &type) { case IndexType::FAISS_IVFSQ8_MIX: { return IndexType::FAISS_IVFSQ8_CPU; } - default: { - return type; - } + default: { return type; } } } IndexType -ConvertToGpuIndexType(const IndexType &type) { +ConvertToGpuIndexType(const IndexType& type) { switch (type) { case IndexType::FAISS_IVFFLAT_MIX: case IndexType::FAISS_IVFFLAT_CPU: { @@ -269,12 +265,9 @@ ConvertToGpuIndexType(const IndexType &type) { case IndexType::FAISS_IVFSQ8_CPU: { return IndexType::FAISS_IVFSQ8_GPU; } - default: { - return type; - } + default: { return type; } } } -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/src/wrapper/VecIndex.h b/cpp/src/wrapper/VecIndex.h index e45f3d0561ede21382c208879f7c27e0d25fd63b..45631d71c1bc30adf501d4034a1af1644dca8165 100644 --- a/cpp/src/wrapper/VecIndex.h +++ b/cpp/src/wrapper/VecIndex.h @@ -15,36 +15,34 @@ // specific language governing permissions and limitations // under the License. - #pragma once -#include #include +#include -#include "utils/Status.h" -#include "knowhere/common/Config.h" #include "knowhere/common/BinarySet.h" +#include "knowhere/common/Config.h" #include "knowhere/index/vector_index/Quantizer.h" +#include "utils/Status.h" -namespace zilliz { namespace milvus { namespace engine { -using Config = zilliz::knowhere::Config; +using Config = knowhere::Config; enum class IndexType { INVALID = 0, FAISS_IDMAP = 1, FAISS_IVFFLAT_CPU, FAISS_IVFFLAT_GPU, - FAISS_IVFFLAT_MIX, // build on gpu and search on cpu + FAISS_IVFFLAT_MIX, // build on gpu and search on cpu FAISS_IVFPQ_CPU, FAISS_IVFPQ_GPU, SPTAG_KDT_RNT_CPU, FAISS_IVFSQ8_MIX, FAISS_IVFSQ8_CPU, FAISS_IVFSQ8_GPU, - FAISS_IVFSQ8_HYBRID, // only support build on gpu. + FAISS_IVFSQ8_HYBRID, // only support build on gpu. NSG_MIX, }; @@ -55,32 +53,20 @@ using VecIndexPtr = std::shared_ptr; class VecIndex { public: virtual Status - BuildAll(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg, - const int64_t &nt = 0, - const float *xt = nullptr) = 0; + BuildAll(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg, const int64_t& nt = 0, + const float* xt = nullptr) = 0; virtual Status - Add(const int64_t &nb, - const float *xb, - const int64_t *ids, - const Config &cfg = Config()) = 0; + Add(const int64_t& nb, const float* xb, const int64_t* ids, const Config& cfg = Config()) = 0; virtual Status - Search(const int64_t &nq, - const float *xq, - float *dist, - int64_t *ids, - const Config &cfg = Config()) = 0; + Search(const int64_t& nq, const float* xq, float* dist, int64_t* ids, const Config& cfg = Config()) = 0; virtual VecIndexPtr - CopyToGpu(const int64_t &device_id, - const Config &cfg = Config()) = 0; + CopyToGpu(const int64_t& device_id, const Config& cfg = Config()) = 0; virtual VecIndexPtr - CopyToCpu(const Config &cfg = Config()) = 0; + CopyToCpu(const Config& cfg = Config()) = 0; virtual VecIndexPtr Clone() = 0; @@ -97,46 +83,53 @@ class VecIndex { virtual int64_t Count() = 0; - virtual zilliz::knowhere::BinarySet + virtual knowhere::BinarySet Serialize() = 0; virtual Status - Load(const zilliz::knowhere::BinarySet &index_binary) = 0; + Load(const knowhere::BinarySet& index_binary) = 0; // TODO(linxj): refactor later //////////////// virtual knowhere::QuantizerPtr - LoadQuantizer(const Config& conf) { return nullptr; } + LoadQuantizer(const Config& conf) { + return nullptr; + } virtual Status - LoadData(const knowhere::QuantizerPtr &q, const Config &conf) { return Status::OK(); } + LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { + return Status::OK(); + } virtual Status - SetQuantizer(const knowhere::QuantizerPtr& q) { return Status::OK(); } + SetQuantizer(const knowhere::QuantizerPtr& q) { + return Status::OK(); + } virtual Status - UnsetQuantizer() { return Status::OK(); } + UnsetQuantizer() { + return Status::OK(); + } //////////////// }; extern Status -write_index(VecIndexPtr index, const std::string &location); +write_index(VecIndexPtr index, const std::string& location); extern VecIndexPtr -read_index(const std::string &location); +read_index(const std::string& location); extern VecIndexPtr -GetVecIndexFactory(const IndexType &type, const Config &cfg = Config()); +GetVecIndexFactory(const IndexType& type, const Config& cfg = Config()); extern VecIndexPtr -LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary); +LoadVecIndex(const IndexType& index_type, const knowhere::BinarySet& index_binary); extern IndexType -ConvertToCpuIndexType(const IndexType &type); +ConvertToCpuIndexType(const IndexType& type); extern IndexType -ConvertToGpuIndexType(const IndexType &type); +ConvertToGpuIndexType(const IndexType& type); -} // namespace engine -} // namespace milvus -} // namespace zilliz +} // namespace engine +} // namespace milvus diff --git a/cpp/unittest/CMakeLists.txt b/cpp/unittest/CMakeLists.txt index 0df8bfd79169d88009e7dbcf56521563c10feefd..6c9aeadcd123683abb6c846c6d815e6a4b1e647a 100644 --- a/cpp/unittest/CMakeLists.txt +++ b/cpp/unittest/CMakeLists.txt @@ -17,25 +17,25 @@ # under the License. #------------------------------------------------------------------------------- +include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") + +foreach(dir ${CORE_INCLUDE_DIRS}) + include_directories(${dir}) +endforeach() + include_directories(${MILVUS_SOURCE_DIR}) include_directories(${MILVUS_ENGINE_SRC}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_files) +link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") +aux_source_directory(${MILVUS_ENGINE_SRC}/cache cache_files) aux_source_directory(${MILVUS_ENGINE_SRC}/config config_files) - +aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler db_scheduler_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/context db_scheduler_context_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/scheduler/task db_scheduler_task_files) -set(db_scheduler_files - ${db_scheduler_files} - ${db_scheduler_context_files} - ${db_scheduler_task_files} - ) set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc @@ -44,8 +44,6 @@ set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-status/status.pb.cc ) -aux_source_directory(${MILVUS_ENGINE_SRC}/metrics metrics_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler scheduler_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/action scheduler_action_files) aux_source_directory(${MILVUS_ENGINE_SRC}/scheduler/event scheduler_event_files) @@ -63,12 +61,10 @@ set(scheduler_files aux_source_directory(${MILVUS_ENGINE_SRC}/server server_files) aux_source_directory(${MILVUS_ENGINE_SRC}/server/grpc_impl grpc_server_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/utils utils_files) - aux_source_directory(${MILVUS_ENGINE_SRC}/wrapper wrapper_files) -set(unittest_files +set(entry_file ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp) set(helper_files @@ -87,7 +83,6 @@ set(common_files ${db_engine_files} ${db_insert_files} ${db_meta_files} - ${db_scheduler_files} ${metrics_files} ${scheduler_files} ${wrapper_files} @@ -118,10 +113,6 @@ set(unittest_libs cublas ) -foreach(dir ${CORE_INCLUDE_DIRS}) - include_directories(${dir}) -endforeach() - add_subdirectory(db) add_subdirectory(wrapper) add_subdirectory(metrics) diff --git a/cpp/unittest/db/CMakeLists.txt b/cpp/unittest/db/CMakeLists.txt index 91a51e484b46651931b101be2f6bbd54b5d19700..2cbf55a208d79b8a59be9c48a34f5908b5bd35b4 100644 --- a/cpp/unittest/db/CMakeLists.txt +++ b/cpp/unittest/db/CMakeLists.txt @@ -20,18 +20,14 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} test_files) -include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") -link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") - - -set(db_test_files +cuda_add_executable(test_db ${common_files} ${test_files} - ) - -cuda_add_executable(db_test ${db_test_files}) + ) -target_link_libraries(db_test knowhere ${unittest_libs}) +target_link_libraries(test_db + knowhere + ${unittest_libs}) -install(TARGETS db_test DESTINATION unittest) +install(TARGETS test_db DESTINATION unittest) diff --git a/cpp/unittest/db/search_test.cpp b/cpp/unittest/db/search_test.cpp deleted file mode 100644 index 33917ea1b7f4d092c39c1934eecc80cbe3e25341..0000000000000000000000000000000000000000 --- a/cpp/unittest/db/search_test.cpp +++ /dev/null @@ -1,295 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include - -#include "scheduler/task/SearchTask.h" -#include "utils/TimeRecorder.h" - -using namespace zilliz::milvus; - -namespace { - -static constexpr uint64_t NQ = 15; -static constexpr uint64_t TOP_K = 64; - -void BuildResult(uint64_t nq, - uint64_t topk, - bool ascending, - std::vector &output_ids, - std::vector &output_distence) { - output_ids.clear(); - output_ids.resize(nq*topk); - output_distence.clear(); - output_distence.resize(nq*topk); - - for(uint64_t i = 0; i < nq; i++) { - for(uint64_t j = 0; j < topk; j++) { - output_ids[i * topk + j] = (long)(drand48()*100000); - output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); - } - } -} - -void CheckResult(const scheduler::Id2DistanceMap& src_1, - const scheduler::Id2DistanceMap& src_2, - const scheduler::Id2DistanceMap& target, - bool ascending) { - for(uint64_t i = 0; i < target.size() - 1; i++) { - if(ascending) { - ASSERT_LE(target[i].second, target[i + 1].second); - } else { - ASSERT_GE(target[i].second, target[i + 1].second); - } - } - - using ID2DistMap = std::map; - ID2DistMap src_map_1, src_map_2; - for(const auto& pair : src_1) { - src_map_1.insert(pair); - } - for(const auto& pair : src_2) { - src_map_2.insert(pair); - } - - for(const auto& pair : target) { - ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end()); - - float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first]; - ASSERT_LT(fabs(pair.second - dist), std::numeric_limits::epsilon()); - } -} - -void CheckCluster(const std::vector& target_ids, - const std::vector& target_distence, - const scheduler::ResultSet& src_result, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for(int64_t i = 0; i < nq; i++) { - auto& res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if(res.empty()) { - continue; - } - - ASSERT_EQ(res[0].first, target_ids[i*topk]); - ASSERT_EQ(res[topk - 1].first, target_ids[i*topk + topk - 1]); - } -} - -void CheckTopkResult(const scheduler::ResultSet& src_result, - bool ascending, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for(int64_t i = 0; i < nq; i++) { - auto& res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if(res.empty()) { - continue; - } - - for(int64_t k = 0; k < topk - 1; k++) { - if(ascending) { - ASSERT_LE(res[k].second, res[k + 1].second); - } else { - ASSERT_GE(res[k].second, res[k + 1].second); - } - } - } -} - -} - -TEST(DBSearchTest, TOPK_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - scheduler::ResultSet src_result; - auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_FALSE(status.ok()); - ASSERT_TRUE(src_result.empty()); - - BuildResult(NQ, TOP_K, ascending, target_ids, target_distence); - status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), NQ); - - scheduler::ResultSet target_result; - status = scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - - status = scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result); - ASSERT_FALSE(status.ok()); - - status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - ASSERT_TRUE(src_result.empty()); - ASSERT_EQ(target_result.size(), NQ); - - std::vector src_ids; - std::vector src_distence; - uint64_t wrong_topk = TOP_K - 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result); - ASSERT_TRUE(status.ok()); - - status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for(uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); - } - - wrong_topk = TOP_K + 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for(uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); - } -} - -TEST(DBSearchTest, MERGE_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - std::vector src_ids; - std::vector src_distence; - scheduler::ResultSet src_result, target_result; - - uint64_t src_count = 5, target_count = 8; - BuildResult(1, src_count, ascending, src_ids, src_distence); - BuildResult(1, target_count, ascending, target_ids, target_distence); - auto status = scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result); - ASSERT_TRUE(status.ok()); - status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result); - ASSERT_TRUE(status.ok()); - - { - scheduler::Id2DistanceMap src = src_result[0]; - scheduler::Id2DistanceMap target = target_result[0]; - status = scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), 10); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - scheduler::Id2DistanceMap src = src_result[0]; - scheduler::Id2DistanceMap target; - status = scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count); - ASSERT_TRUE(src.empty()); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - scheduler::Id2DistanceMap src = src_result[0]; - scheduler::Id2DistanceMap target = target_result[0]; - status = scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - scheduler::Id2DistanceMap target = src_result[0]; - scheduler::Id2DistanceMap src = target_result[0]; - status = scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } -} - -TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - scheduler::ResultSet src_result; - - auto DoCluster = [&](int64_t nq, int64_t topk) { - TimeRecorder rc("DoCluster"); - src_result.clear(); - BuildResult(nq, topk, ascending, target_ids, target_distence); - rc.RecordSection("build id/dietance map"); - - auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), nq); - - rc.RecordSection("cluster result"); - - CheckCluster(target_ids, target_distence, src_result, nq, topk); - rc.RecordSection("check result"); - }; - - DoCluster(10000, 1000); - DoCluster(333, 999); - DoCluster(1, 1000); - DoCluster(1, 1); - DoCluster(7, 0); - DoCluster(9999, 1); - DoCluster(10001, 1); - DoCluster(58273, 1234); -} - -TEST(DBSearchTest, PARALLEL_TOPK_TEST) { - std::vector target_ids; - std::vector target_distence; - scheduler::ResultSet src_result; - - std::vector insufficient_ids; - std::vector insufficient_distence; - scheduler::ResultSet insufficient_result; - - auto DoTopk = [&](int64_t nq, int64_t topk,int64_t insufficient_topk, bool ascending) { - src_result.clear(); - insufficient_result.clear(); - - TimeRecorder rc("DoCluster"); - - BuildResult(nq, topk, ascending, target_ids, target_distence); - auto status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - rc.RecordSection("cluster result"); - - BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence); - status = scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk, insufficient_result); - rc.RecordSection("cluster result"); - - scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result); - ASSERT_TRUE(status.ok()); - rc.RecordSection("topk"); - - CheckTopkResult(src_result, ascending, nq, topk); - rc.RecordSection("check result"); - }; - - DoTopk(5, 10, 4, false); - DoTopk(20005, 998, 123, true); -// DoTopk(9987, 12, 10, false); -// DoTopk(77777, 1000, 1, false); -// DoTopk(5432, 8899, 8899, true); -} \ No newline at end of file diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/test_db.cpp similarity index 59% rename from cpp/unittest/db/db_tests.cpp rename to cpp/unittest/db/test_db.cpp index f2f81215388723850b21946babcdbcedbe6e69a3..eb1f947d610cb9ebddfce6c36be19abcb78c5a8a 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/test_db.cpp @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -#include "utils.h" +#include "db/utils.h" #include "db/DB.h" #include "db/DBImpl.h" +#include "db/Constants.h" #include "db/meta/MetaConsts.h" #include "db/DBFactory.h" #include "cache/CpuCacheMgr.h" @@ -28,139 +29,142 @@ #include #include -using namespace zilliz::milvus; - namespace { - static const char* TABLE_NAME = "test_group"; - static constexpr int64_t TABLE_DIM = 256; - static constexpr int64_t VECTOR_COUNT = 25000; - static constexpr int64_t INSERT_LOOP = 1000; - static constexpr int64_t SECONDS_EACH_HOUR = 3600; - static constexpr int64_t DAY_SECONDS = 24 * 60 * 60; - - engine::meta::TableSchema BuildTableSchema() { - engine::meta::TableSchema table_info; - table_info.dimension_ = TABLE_DIM; - table_info.table_id_ = TABLE_NAME; - return table_info; - } +namespace ms = milvus; + +static const char *TABLE_NAME = "test_group"; +static constexpr int64_t TABLE_DIM = 256; +static constexpr int64_t VECTOR_COUNT = 25000; +static constexpr int64_t INSERT_LOOP = 1000; +static constexpr int64_t SECONDS_EACH_HOUR = 3600; +static constexpr int64_t DAY_SECONDS = 24 * 60 * 60; + +ms::engine::meta::TableSchema +BuildTableSchema() { + ms::engine::meta::TableSchema table_info; + table_info.dimension_ = TABLE_DIM; + table_info.table_id_ = TABLE_NAME; + return table_info; +} - void BuildVectors(int64_t n, std::vector& vectors) { - vectors.clear(); - vectors.resize(n*TABLE_DIM); - float* data = vectors.data(); - for(int i = 0; i < n; i++) { - for(int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48(); - data[TABLE_DIM * i] += i / 2000.; - } +void +BuildVectors(int64_t n, std::vector &vectors) { + vectors.clear(); + vectors.resize(n * TABLE_DIM); + float *data = vectors.data(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48(); + data[TABLE_DIM * i] += i / 2000.; } +} - std::string CurrentTmDate(int64_t offset_day = 0) { - time_t tt; - time( &tt ); - tt = tt + 8*SECONDS_EACH_HOUR; - tt = tt + 24*SECONDS_EACH_HOUR*offset_day; - tm* t= gmtime( &tt ); +std::string +CurrentTmDate(int64_t offset_day = 0) { + time_t tt; + time(&tt); + tt = tt + 8 * SECONDS_EACH_HOUR; + tt = tt + 24 * SECONDS_EACH_HOUR * offset_day; + tm t; + gmtime_r(&tt, &t); - std::string str = std::to_string(t->tm_year + 1900) + "-" + std::to_string(t->tm_mon + 1) - + "-" + std::to_string(t->tm_mday); + std::string str = std::to_string(t.tm_year + 1900) + "-" + std::to_string(t.tm_mon + 1) + + "-" + std::to_string(t.tm_mday); - return str; - } + return str; +} - void - ConvertTimeRangeToDBDates(const std::string &start_value, - const std::string &end_value, - std::vector &dates) { - dates.clear(); +void +ConvertTimeRangeToDBDates(const std::string &start_value, + const std::string &end_value, + std::vector &dates) { + dates.clear(); - time_t tt_start, tt_end; - tm tm_start, tm_end; - if (!zilliz::milvus::server::CommonUtil::TimeStrToTime(start_value, tt_start, tm_start)) { - return; - } + time_t tt_start, tt_end; + tm tm_start, tm_end; + if (!milvus::server::CommonUtil::TimeStrToTime(start_value, tt_start, tm_start)) { + return; + } - if (!zilliz::milvus::server::CommonUtil::TimeStrToTime(end_value, tt_end, tm_end)) { - return; - } + if (!milvus::server::CommonUtil::TimeStrToTime(end_value, tt_end, tm_end)) { + return; + } - long days = (tt_end > tt_start) ? (tt_end - tt_start) / DAY_SECONDS : (tt_start - tt_end) / - DAY_SECONDS; - if (days == 0) { - return; - } + int64_t days = (tt_end > tt_start) ? (tt_end - tt_start) / DAY_SECONDS : (tt_start - tt_end) / + DAY_SECONDS; + if (days == 0) { + return; + } - for (long i = 0; i < days; i++) { - time_t tt_day = tt_start + DAY_SECONDS * i; - tm tm_day; - zilliz::milvus::server::CommonUtil::ConvertTime(tt_day, tm_day); + for (int64_t i = 0; i < days; i++) { + time_t tt_day = tt_start + DAY_SECONDS * i; + tm tm_day; + milvus::server::CommonUtil::ConvertTime(tt_day, tm_day); - long date = tm_day.tm_year * 10000 + tm_day.tm_mon * 100 + - tm_day.tm_mday;//according to db logic - dates.push_back(date); - } + int64_t date = tm_day.tm_year * 10000 + tm_day.tm_mon * 100 + + tm_day.tm_mday;//according to db logic + dates.push_back(date); } - } +} // namespace + TEST_F(DBTest, CONFIG_TEST) { { - ASSERT_ANY_THROW(engine::ArchiveConf conf("wrong")); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf("wrong")); /* EXPECT_DEATH(engine::ArchiveConf conf("wrong"), ""); */ } { - engine::ArchiveConf conf("delete"); + ms::engine::ArchiveConf conf("delete"); ASSERT_EQ(conf.GetType(), "delete"); auto criterias = conf.GetCriterias(); - ASSERT_TRUE(criterias.size() == 0); + ASSERT_EQ(criterias.size(), 0); } { - engine::ArchiveConf conf("swap"); + ms::engine::ArchiveConf conf("swap"); ASSERT_EQ(conf.GetType(), "swap"); auto criterias = conf.GetCriterias(); - ASSERT_TRUE(criterias.size() == 0); + ASSERT_EQ(criterias.size(), 0); } { - ASSERT_ANY_THROW(engine::ArchiveConf conf1("swap", "disk:")); - ASSERT_ANY_THROW(engine::ArchiveConf conf2("swap", "disk:a")); - engine::ArchiveConf conf("swap", "disk:1024"); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf1("swap", "disk:")); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf2("swap", "disk:a")); + ms::engine::ArchiveConf conf("swap", "disk:1024"); auto criterias = conf.GetCriterias(); - ASSERT_TRUE(criterias.size() == 1); - ASSERT_TRUE(criterias["disk"] == 1024); + ASSERT_EQ(criterias.size(), 1); + ASSERT_EQ(criterias["disk"], 1024); } { - ASSERT_ANY_THROW(engine::ArchiveConf conf1("swap", "days:")); - ASSERT_ANY_THROW(engine::ArchiveConf conf2("swap", "days:a")); - engine::ArchiveConf conf("swap", "days:100"); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf1("swap", "days:")); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf2("swap", "days:a")); + ms::engine::ArchiveConf conf("swap", "days:100"); auto criterias = conf.GetCriterias(); - ASSERT_TRUE(criterias.size() == 1); - ASSERT_TRUE(criterias["days"] == 100); + ASSERT_EQ(criterias.size(), 1); + ASSERT_EQ(criterias["days"], 100); } { - ASSERT_ANY_THROW(engine::ArchiveConf conf1("swap", "days:")); - ASSERT_ANY_THROW(engine::ArchiveConf conf2("swap", "days:a")); - engine::ArchiveConf conf("swap", "days:100;disk:200"); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf1("swap", "days:")); + ASSERT_ANY_THROW(ms::engine::ArchiveConf conf2("swap", "days:a")); + ms::engine::ArchiveConf conf("swap", "days:100;disk:200"); auto criterias = conf.GetCriterias(); - ASSERT_TRUE(criterias.size() == 2); - ASSERT_TRUE(criterias["days"] == 100); - ASSERT_TRUE(criterias["disk"] == 200); + ASSERT_EQ(criterias.size(), 2); + ASSERT_EQ(criterias["days"], 100); + ASSERT_EQ(criterias["disk"], 200); } } - TEST_F(DBTest, DB_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); - engine::IDNumbers vector_ids; - engine::IDNumbers target_ids; + ms::engine::IDNumbers vector_ids; + ms::engine::IDNumbers target_ids; int64_t nb = 50; std::vector xb; @@ -171,7 +175,7 @@ TEST_F(DBTest, DB_TEST) { BuildVectors(qb, qxb); std::thread search([&]() { - engine::QueryResults results; + ms::engine::QueryResults results; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -180,18 +184,18 @@ TEST_F(DBTest, DB_TEST) { uint64_t count = 0; uint64_t prev_count = 0; - for (auto j=0; j<10; ++j) { + for (auto j = 0; j < 10; ++j) { ss.str(""); db_->Size(count); prev_count = count; START_TIMER; stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); - ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; + ss << "Search " << j << " With Size " << count / ms::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); - for (auto k=0; kInsertVectors(TABLE_NAME, qb, qxb.data(), target_ids); ASSERT_EQ(target_ids.size(), qb); } else { @@ -222,14 +226,14 @@ TEST_F(DBTest, DB_TEST) { uint64_t count; stat = db_->GetTableRowCount(TABLE_NAME, count); ASSERT_TRUE(stat.ok()); - ASSERT_TRUE(count > 0); -}; + ASSERT_GT(count, 0); +} TEST_F(DBTest, SEARCH_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -239,63 +243,63 @@ TEST_F(DBTest, SEARCH_TEST) { size_t nb = VECTOR_COUNT; size_t nq = 10; size_t k = 5; - std::vector xb(nb*TABLE_DIM); - std::vector xq(nq*TABLE_DIM); - std::vector ids(nb); + std::vector xb(nb * TABLE_DIM); + std::vector xq(nq * TABLE_DIM); + std::vector ids(nb); std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution<> dis_xt(-1.0, 1.0); - for (size_t i = 0; i < nb*TABLE_DIM; i++) { + for (size_t i = 0; i < nb * TABLE_DIM; i++) { xb[i] = dis_xt(gen); - if (i < nb){ + if (i < nb) { ids[i] = i; } } - for (size_t i = 0; i < nq*TABLE_DIM; i++) { + for (size_t i = 0; i < nq * TABLE_DIM; i++) { xq[i] = dis_xt(gen); } // result data //std::vector nns_gt(k*nq); - std::vector nns(k*nq); // nns = nearst neg search + std::vector nns(k * nq); // nns = nearst neg search //std::vector dis_gt(k*nq); - std::vector dis(k*nq); + std::vector dis(k * nq); // insert data const int batch_size = 100; for (int j = 0; j < nb / batch_size; ++j) { - stat = db_->InsertVectors(TABLE_NAME, batch_size, xb.data()+batch_size*j*TABLE_DIM, ids); - if (j == 200){ sleep(1);} + stat = db_->InsertVectors(TABLE_NAME, batch_size, xb.data() + batch_size * j * TABLE_DIM, ids); + if (j == 200) { sleep(1); } ASSERT_TRUE(stat.ok()); } - engine::TableIndex index; - index.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; + ms::engine::TableIndex index; + index.engine_type_ = (int) ms::engine::EngineType::FAISS_IDMAP; db_->CreateIndex(TABLE_NAME, index); // wait until build index finish { - engine::QueryResults results; + ms::engine::QueryResults results; stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); ASSERT_TRUE(stat.ok()); } {//search by specify index file - engine::meta::DatesT dates; + ms::engine::meta::DatesT dates; std::vector file_ids = {"1", "2", "3", "4", "5", "6"}; - engine::QueryResults results; + ms::engine::QueryResults results; stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); ASSERT_TRUE(stat.ok()); } - // TODO(linxj): add groundTruth assert -}; + // TODO(lxj): add groundTruth assert +} TEST_F(DBTest, PRELOADTABLE_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -306,28 +310,27 @@ TEST_F(DBTest, PRELOADTABLE_TEST) { BuildVectors(nb, xb); int loop = 5; - for (auto i=0; iInsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); ASSERT_EQ(vector_ids.size(), nb); } - engine::TableIndex index; - index.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; + ms::engine::TableIndex index; + index.engine_type_ = (int) ms::engine::EngineType::FAISS_IDMAP; db_->CreateIndex(TABLE_NAME, index); // wait until build index finish - int64_t prev_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); + int64_t prev_cache_usage = ms::cache::CpuCacheMgr::GetInstance()->CacheUsage(); stat = db_->PreloadTable(TABLE_NAME); ASSERT_TRUE(stat.ok()); - int64_t cur_cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); + int64_t cur_cache_usage = ms::cache::CpuCacheMgr::GetInstance()->CacheUsage(); ASSERT_TRUE(prev_cache_usage < cur_cache_usage); - } TEST_F(DBTest, SHUTDOWN_TEST) { db_->Stop(); - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); ASSERT_FALSE(stat.ok()); @@ -338,7 +341,7 @@ TEST_F(DBTest, SHUTDOWN_TEST) { stat = db_->HasTable(table_info.table_id_, has_table); ASSERT_FALSE(stat.ok()); - engine::IDNumbers ids; + ms::engine::IDNumbers ids; stat = db_->InsertVectors(table_info.table_id_, 0, nullptr, ids); ASSERT_FALSE(stat.ok()); @@ -349,15 +352,15 @@ TEST_F(DBTest, SHUTDOWN_TEST) { stat = db_->GetTableRowCount(table_info.table_id_, row_count); ASSERT_FALSE(stat.ok()); - engine::TableIndex index; + ms::engine::TableIndex index; stat = db_->CreateIndex(table_info.table_id_, index); ASSERT_FALSE(stat.ok()); stat = db_->DescribeIndex(table_info.table_id_, index); ASSERT_FALSE(stat.ok()); - engine::meta::DatesT dates; - engine::QueryResults results; + ms::engine::meta::DatesT dates; + ms::engine::QueryResults results; stat = db_->Query(table_info.table_id_, 1, 1, 1, nullptr, dates, results); ASSERT_FALSE(stat.ok()); std::vector file_ids; @@ -369,24 +372,24 @@ TEST_F(DBTest, SHUTDOWN_TEST) { } TEST_F(DBTest, INDEX_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); int64_t nb = VECTOR_COUNT; std::vector xb; BuildVectors(nb, xb); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); ASSERT_EQ(vector_ids.size(), nb); - engine::TableIndex index; - index.engine_type_ = (int)engine::EngineType::FAISS_IVFSQ8; - index.metric_type_ = (int)engine::MetricType::IP; + ms::engine::TableIndex index; + index.engine_type_ = (int) ms::engine::EngineType::FAISS_IVFSQ8; + index.metric_type_ = (int) ms::engine::MetricType::IP; stat = db_->CreateIndex(table_info.table_id_, index); ASSERT_TRUE(stat.ok()); - engine::TableIndex index_out; + ms::engine::TableIndex index_out; stat = db_->DescribeIndex(table_info.table_id_, index_out); ASSERT_TRUE(stat.ok()); ASSERT_EQ(index.engine_type_, index_out.engine_type_); @@ -398,23 +401,22 @@ TEST_F(DBTest, INDEX_TEST) { } TEST_F(DBTest2, ARHIVE_DISK_CHECK) { - - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - std::vector table_schema_array; + std::vector table_schema_array; stat = db_->AllTables(table_schema_array); ASSERT_TRUE(stat.ok()); bool bfound = false; - for(auto& schema : table_schema_array) { - if(schema.table_id_ == TABLE_NAME) { + for (auto &schema : table_schema_array) { + if (schema.table_id_ == TABLE_NAME) { bfound = true; break; } } ASSERT_TRUE(bfound); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -428,8 +430,8 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { BuildVectors(nb, xb); int loop = INSERT_LOOP; - for (auto i=0; iInsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); std::this_thread::sleep_for(std::chrono::microseconds(1)); } @@ -438,14 +440,14 @@ TEST_F(DBTest2, ARHIVE_DISK_CHECK) { db_->Size(size); LOG(DEBUG) << "size=" << size; - ASSERT_LE(size, 1 * engine::meta::G); -}; + ASSERT_LE(size, 1 * ms::engine::G); +} TEST_F(DBTest2, DELETE_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -461,25 +463,25 @@ TEST_F(DBTest2, DELETE_TEST) { std::vector xb; BuildVectors(nb, xb); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); - engine::TableIndex index; + ms::engine::TableIndex index; stat = db_->CreateIndex(TABLE_NAME, index); - std::vector dates; + std::vector dates; stat = db_->DeleteTable(TABLE_NAME, dates); std::this_thread::sleep_for(std::chrono::seconds(2)); ASSERT_TRUE(stat.ok()); db_->HasTable(TABLE_NAME, has_table); ASSERT_FALSE(has_table); -}; +} TEST_F(DBTest2, DELETE_BY_RANGE_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -496,15 +498,15 @@ TEST_F(DBTest2, DELETE_BY_RANGE_TEST) { std::vector xb; BuildVectors(nb, xb); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); - engine::TableIndex index; + ms::engine::TableIndex index; stat = db_->CreateIndex(TABLE_NAME, index); db_->Size(size); ASSERT_NE(size, 0UL); - std::vector dates; + std::vector dates; std::string start_value = CurrentTmDate(); std::string end_value = CurrentTmDate(1); ConvertTimeRangeToDBDates(start_value, end_value, dates); @@ -515,4 +517,4 @@ TEST_F(DBTest2, DELETE_BY_RANGE_TEST) { uint64_t row_count = 0; db_->GetTableRowCount(TABLE_NAME, row_count); ASSERT_EQ(row_count, 0UL); -} \ No newline at end of file +} diff --git a/cpp/unittest/db/mysql_db_test.cpp b/cpp/unittest/db/test_db_mysql.cpp similarity index 70% rename from cpp/unittest/db/mysql_db_test.cpp rename to cpp/unittest/db/test_db_mysql.cpp index 53ce51ca45f5434a7eabf39206bafa8fe5bf64f1..3b73deb9cc5d81cd512834ec046cba69221909bd 100644 --- a/cpp/unittest/db/mysql_db_test.cpp +++ b/cpp/unittest/db/test_db_mysql.cpp @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -#include "utils.h" +#include "db/utils.h" #include "db/DB.h" #include "db/DBImpl.h" +#include "db/Constants.h" #include "db/meta/MetaConsts.h" #include @@ -26,48 +27,49 @@ #include #include -using namespace zilliz::milvus; - namespace { - static const char* TABLE_NAME = "test_group"; - static constexpr int64_t TABLE_DIM = 256; - static constexpr int64_t VECTOR_COUNT = 25000; - static constexpr int64_t INSERT_LOOP = 1000; - - engine::meta::TableSchema BuildTableSchema() { - engine::meta::TableSchema table_info; - table_info.dimension_ = TABLE_DIM; - table_info.table_id_ = TABLE_NAME; - table_info.engine_type_ = (int)engine::EngineType::FAISS_IDMAP; - return table_info; - } +namespace ms = milvus; - void BuildVectors(int64_t n, std::vector& vectors) { - vectors.clear(); - vectors.resize(n*TABLE_DIM); - float* data = vectors.data(); - for(int i = 0; i < n; i++) { - for(int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48(); - data[TABLE_DIM * i] += i / 2000.; - } - } +static const char *TABLE_NAME = "test_group"; +static constexpr int64_t TABLE_DIM = 256; +static constexpr int64_t VECTOR_COUNT = 25000; +static constexpr int64_t INSERT_LOOP = 1000; + +ms::engine::meta::TableSchema +BuildTableSchema() { + ms::engine::meta::TableSchema table_info; + table_info.dimension_ = TABLE_DIM; + table_info.table_id_ = TABLE_NAME; + table_info.engine_type_ = (int) ms::engine::EngineType::FAISS_IDMAP; + return table_info; +} +void +BuildVectors(int64_t n, std::vector &vectors) { + vectors.clear(); + vectors.resize(n * TABLE_DIM); + float *data = vectors.data(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48(); + data[TABLE_DIM * i] += i / 2000.; + } } +} // namespace TEST_F(MySqlDBTest, DB_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); - engine::IDNumbers vector_ids; - engine::IDNumbers target_ids; + ms::engine::IDNumbers vector_ids; + ms::engine::IDNumbers target_ids; int64_t nb = 50; std::vector xb; @@ -81,7 +83,7 @@ TEST_F(MySqlDBTest, DB_TEST) { ASSERT_EQ(target_ids.size(), qb); std::thread search([&]() { - engine::QueryResults results; + ms::engine::QueryResults results; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(5)); @@ -90,22 +92,22 @@ TEST_F(MySqlDBTest, DB_TEST) { uint64_t count = 0; uint64_t prev_count = 0; - for (auto j=0; j<10; ++j) { + for (auto j = 0; j < 10; ++j) { ss.str(""); db_->Size(count); prev_count = count; START_TIMER; stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); - ss << "Search " << j << " With Size " << count/engine::meta::M << " M"; + ss << "Search " << j << " With Size " << count / ms::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); - for (auto k=0; kInsertVectors(TABLE_NAME, qb, qxb.data(), target_ids); // ASSERT_EQ(target_ids.size(), qb); @@ -139,13 +141,13 @@ TEST_F(MySqlDBTest, DB_TEST) { } search.join(); -}; +} TEST_F(MySqlDBTest, SEARCH_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -155,68 +157,68 @@ TEST_F(MySqlDBTest, SEARCH_TEST) { size_t nb = VECTOR_COUNT; size_t nq = 10; size_t k = 5; - std::vector xb(nb*TABLE_DIM); - std::vector xq(nq*TABLE_DIM); - std::vector ids(nb); + std::vector xb(nb * TABLE_DIM); + std::vector xq(nq * TABLE_DIM); + std::vector ids(nb); std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution<> dis_xt(-1.0, 1.0); - for (size_t i = 0; i < nb*TABLE_DIM; i++) { + for (size_t i = 0; i < nb * TABLE_DIM; i++) { xb[i] = dis_xt(gen); - if (i < nb){ + if (i < nb) { ids[i] = i; } } - for (size_t i = 0; i < nq*TABLE_DIM; i++) { + for (size_t i = 0; i < nq * TABLE_DIM; i++) { xq[i] = dis_xt(gen); } // result data //std::vector nns_gt(k*nq); - std::vector nns(k*nq); // nns = nearst neg search + std::vector nns(k * nq); // nns = nearst neg search //std::vector dis_gt(k*nq); - std::vector dis(k*nq); + std::vector dis(k * nq); // insert data const int batch_size = 100; for (int j = 0; j < nb / batch_size; ++j) { - stat = db_->InsertVectors(TABLE_NAME, batch_size, xb.data()+batch_size*j*TABLE_DIM, ids); - if (j == 200){ sleep(1);} + stat = db_->InsertVectors(TABLE_NAME, batch_size, xb.data() + batch_size * j * TABLE_DIM, ids); + if (j == 200) { sleep(1); } ASSERT_TRUE(stat.ok()); } sleep(2); // wait until build index finish - engine::QueryResults results; + ms::engine::QueryResults results; stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); ASSERT_TRUE(stat.ok()); -}; +} TEST_F(MySqlDBTest, ARHIVE_DISK_CHECK) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - std::vector table_schema_array; + std::vector table_schema_array; stat = db_->AllTables(table_schema_array); ASSERT_TRUE(stat.ok()); bool bfound = false; - for(auto& schema : table_schema_array) { - if(schema.table_id_ == TABLE_NAME) { + for (auto &schema : table_schema_array) { + if (schema.table_id_ == TABLE_NAME) { bfound = true; break; } } ASSERT_TRUE(bfound); - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); - engine::IDNumbers vector_ids; - engine::IDNumbers target_ids; + ms::engine::IDNumbers vector_ids; + ms::engine::IDNumbers target_ids; uint64_t size; db_->Size(size); @@ -226,7 +228,7 @@ TEST_F(MySqlDBTest, ARHIVE_DISK_CHECK) { BuildVectors(nb, xb); int loop = INSERT_LOOP; - for (auto i=0; iInsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); std::this_thread::sleep_for(std::chrono::microseconds(1)); } @@ -235,15 +237,15 @@ TEST_F(MySqlDBTest, ARHIVE_DISK_CHECK) { db_->Size(size); LOG(DEBUG) << "size=" << size; - ASSERT_LE(size, 1 * engine::meta::G); -}; + ASSERT_LE(size, 1 * ms::engine::G); +} TEST_F(MySqlDBTest, DELETE_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); // std::cout << stat.ToString() << std::endl; - engine::meta::TableSchema table_info_get; + ms::engine::meta::TableSchema table_info_get; table_info_get.table_id_ = TABLE_NAME; stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); @@ -252,7 +254,7 @@ TEST_F(MySqlDBTest, DELETE_TEST) { db_->HasTable(TABLE_NAME, has_table); ASSERT_TRUE(has_table); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; uint64_t size; db_->Size(size); @@ -262,7 +264,7 @@ TEST_F(MySqlDBTest, DELETE_TEST) { BuildVectors(nb, xb); int loop = 20; - for (auto i=0; iInsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); std::this_thread::sleep_for(std::chrono::microseconds(1)); } @@ -276,4 +278,5 @@ TEST_F(MySqlDBTest, DELETE_TEST) { // // db_->HasTable(TABLE_NAME, has_table); // ASSERT_FALSE(has_table); -}; +} + diff --git a/cpp/unittest/db/engine_test.cpp b/cpp/unittest/db/test_engine.cpp similarity index 65% rename from cpp/unittest/db/engine_test.cpp rename to cpp/unittest/db/test_engine.cpp index f55d8d9e27f6eeb51248578f779bec9357e39013..40afbb38f9c4e02cc47dbcc97d2044d978e0dab0 100644 --- a/cpp/unittest/db/engine_test.cpp +++ b/cpp/unittest/db/test_engine.cpp @@ -21,67 +21,66 @@ #include "db/engine/EngineFactory.h" #include "db/engine/ExecutionEngineImpl.h" -#include "utils.h" +#include "db/utils.h" -using namespace zilliz::milvus; +namespace { + +namespace ms = milvus; + +} TEST_F(EngineTest, FACTORY_TEST) { { - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", - engine::EngineType::INVALID, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::INVALID, + ms::engine::MetricType::IP, + 1024); ASSERT_TRUE(engine_ptr == nullptr); } { - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", - engine::EngineType::FAISS_IDMAP, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::FAISS_IDMAP, + ms::engine::MetricType::IP, + 1024); ASSERT_TRUE(engine_ptr != nullptr); } { - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", - engine::EngineType::FAISS_IVFFLAT, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::FAISS_IVFFLAT, + ms::engine::MetricType::IP, + 1024); ASSERT_TRUE(engine_ptr != nullptr); } { - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", - engine::EngineType::FAISS_IVFSQ8, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::FAISS_IVFSQ8, + ms::engine::MetricType::IP, + 1024); ASSERT_TRUE(engine_ptr != nullptr); } { - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", - engine::EngineType::NSG_MIX, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::NSG_MIX, + ms::engine::MetricType::IP, + 1024); ASSERT_TRUE(engine_ptr != nullptr); } @@ -90,27 +89,26 @@ TEST_F(EngineTest, FACTORY_TEST) { TEST_F(EngineTest, ENGINE_IMPL_TEST) { uint16_t dimension = 64; std::string file_path = "/tmp/milvus_index_1"; - auto engine_ptr = engine::EngineFactory::Build( + auto engine_ptr = ms::engine::EngineFactory::Build( dimension, file_path, - engine::EngineType::FAISS_IVFFLAT, - engine::MetricType::IP, - 1024 - ); + ms::engine::EngineType::FAISS_IVFFLAT, + ms::engine::MetricType::IP, + 1024); std::vector data; - std::vector ids; + std::vector ids; const int row_count = 10000; data.reserve(row_count*dimension); ids.reserve(row_count); - for(long i = 0; i < row_count; i++) { + for (int64_t i = 0; i < row_count; i++) { ids.push_back(i); - for(uint16_t k = 0; k < dimension; k++) { + for (uint16_t k = 0; k < dimension; k++) { data.push_back(i*dimension + k); } } - auto status = engine_ptr->AddWithIds((long)ids.size(), data.data(), ids.data()); + auto status = engine_ptr->AddWithIds((int64_t)ids.size(), data.data(), ids.data()); ASSERT_TRUE(status.ok()); ASSERT_EQ(engine_ptr->Dimension(), dimension); @@ -127,5 +125,4 @@ TEST_F(EngineTest, ENGINE_IMPL_TEST) { // // auto engine_build = new_engine->BuildIndex("/tmp/milvus_index_2", engine::EngineType::FAISS_IVFSQ8); // //ASSERT_TRUE(status.ok()); - } diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/test_mem.cpp similarity index 69% rename from cpp/unittest/db/mem_test.cpp rename to cpp/unittest/db/test_mem.cpp index b85718a86518ecb375b4e323ae8e56d7a99c7b25..1e465a69fbd8b3ffae2747860894da60ca5e1666 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/test_mem.cpp @@ -25,7 +25,7 @@ #include "db/engine/EngineFactory.h" #include "db/meta/MetaConsts.h" #include "metrics/Metrics.h" -#include "utils.h" +#include "db/utils.h" #include #include @@ -35,30 +35,32 @@ #include #include -using namespace zilliz::milvus; - namespace { -static std::string TABLE_NAME = "test_group"; +namespace ms = milvus; + static constexpr int64_t TABLE_DIM = 256; -std::string GenTableName() { +std::string +GetTableName() { auto now = std::chrono::system_clock::now(); auto micros = std::chrono::duration_cast( - now.time_since_epoch()).count(); - TABLE_NAME = std::to_string(micros); - return TABLE_NAME; + now.time_since_epoch()).count(); + static std::string table_name = std::to_string(micros); + return table_name; } -engine::meta::TableSchema BuildTableSchema() { - engine::meta::TableSchema table_info; +ms::engine::meta::TableSchema +BuildTableSchema() { + ms::engine::meta::TableSchema table_info; table_info.dimension_ = TABLE_DIM; - table_info.table_id_ = GenTableName(); - table_info.engine_type_ = (int) engine::EngineType::FAISS_IDMAP; + table_info.table_id_ = GetTableName(); + table_info.engine_type_ = (int) ms::engine::EngineType::FAISS_IDMAP; return table_info; } -void BuildVectors(int64_t n, std::vector &vectors) { +void +BuildVectors(int64_t n, std::vector &vectors) { vectors.clear(); vectors.resize(n * TABLE_DIM); float *data = vectors.data(); @@ -67,15 +69,15 @@ void BuildVectors(int64_t n, std::vector &vectors) { data[TABLE_DIM * i + j] = drand48(); } } -} +} // namespace TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) { - engine::meta::TableSchema table_schema = BuildTableSchema(); + ms::engine::meta::TableSchema table_schema = BuildTableSchema(); auto status = impl_->CreateTable(table_schema); ASSERT_TRUE(status.ok()); - engine::meta::TableFileSchema table_file_schema; - table_file_schema.table_id_ = TABLE_NAME; + ms::engine::meta::TableFileSchema table_file_schema; + table_file_schema.table_id_ = GetTableName(); status = impl_->CreateTableFile(table_file_schema); ASSERT_TRUE(status.ok()); @@ -83,17 +85,17 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) { std::vector vectors; BuildVectors(n, vectors); - engine::VectorSource source(n, vectors.data()); + ms::engine::VectorSource source(n, vectors.data()); size_t num_vectors_added; - engine::ExecutionEnginePtr execution_engine_ = - engine::EngineFactory::Build(table_file_schema.dimension_, - table_file_schema.location_, - (engine::EngineType) table_file_schema.engine_type_, - (engine::MetricType)table_file_schema.metric_type_, - table_schema.nlist_); - - engine::IDNumbers vector_ids; + ms::engine::ExecutionEnginePtr execution_engine_ = + ms::engine::EngineFactory::Build(table_file_schema.dimension_, + table_file_schema.location_, + (ms::engine::EngineType) table_file_schema.engine_type_, + (ms::engine::MetricType) table_file_schema.metric_type_, + table_schema.nlist_); + + ms::engine::IDNumbers vector_ids; status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids); ASSERT_TRUE(status.ok()); vector_ids = source.GetVectorIds(); @@ -113,19 +115,19 @@ TEST_F(MemManagerTest, VECTOR_SOURCE_TEST) { TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) { auto options = GetOptions(); - engine::meta::TableSchema table_schema = BuildTableSchema(); + ms::engine::meta::TableSchema table_schema = BuildTableSchema(); auto status = impl_->CreateTable(table_schema); ASSERT_TRUE(status.ok()); - engine::MemTableFile mem_table_file(TABLE_NAME, impl_, options); + ms::engine::MemTableFile mem_table_file(GetTableName(), impl_, options); int64_t n_100 = 100; std::vector vectors_100; BuildVectors(n_100, vectors_100); - engine::VectorSourcePtr source = std::make_shared(n_100, vectors_100.data()); + ms::engine::VectorSourcePtr source = std::make_shared(n_100, vectors_100.data()); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; status = mem_table_file.Add(source, vector_ids); ASSERT_TRUE(status.ok()); @@ -137,11 +139,11 @@ TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) { size_t singleVectorMem = sizeof(float) * TABLE_DIM; ASSERT_EQ(mem_table_file.GetCurrentMem(), n_100 * singleVectorMem); - int64_t n_max = engine::MAX_TABLE_FILE_MEM / singleVectorMem; + int64_t n_max = ms::engine::MAX_TABLE_FILE_MEM / singleVectorMem; std::vector vectors_128M; BuildVectors(n_max, vectors_128M); - engine::VectorSourcePtr source_128M = std::make_shared(n_max, vectors_128M.data()); + ms::engine::VectorSourcePtr source_128M = std::make_shared(n_max, vectors_128M.data()); vector_ids.clear(); status = mem_table_file.Add(source_128M, vector_ids); @@ -154,7 +156,7 @@ TEST_F(MemManagerTest, MEM_TABLE_FILE_TEST) { TEST_F(MemManagerTest, MEM_TABLE_TEST) { auto options = GetOptions(); - engine::meta::TableSchema table_schema = BuildTableSchema(); + ms::engine::meta::TableSchema table_schema = BuildTableSchema(); auto status = impl_->CreateTable(table_schema); ASSERT_TRUE(status.ok()); @@ -162,27 +164,27 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) { std::vector vectors_100; BuildVectors(n_100, vectors_100); - engine::VectorSourcePtr source_100 = std::make_shared(n_100, vectors_100.data()); + ms::engine::VectorSourcePtr source_100 = std::make_shared(n_100, vectors_100.data()); - engine::MemTable mem_table(TABLE_NAME, impl_, options); + ms::engine::MemTable mem_table(GetTableName(), impl_, options); - engine::IDNumbers vector_ids; + ms::engine::IDNumbers vector_ids; status = mem_table.Add(source_100, vector_ids); ASSERT_TRUE(status.ok()); vector_ids = source_100->GetVectorIds(); ASSERT_EQ(vector_ids.size(), 100); - engine::MemTableFilePtr mem_table_file; + ms::engine::MemTableFilePtr mem_table_file; mem_table.GetCurrentMemTableFile(mem_table_file); size_t singleVectorMem = sizeof(float) * TABLE_DIM; ASSERT_EQ(mem_table_file->GetCurrentMem(), n_100 * singleVectorMem); - int64_t n_max = engine::MAX_TABLE_FILE_MEM / singleVectorMem; + int64_t n_max = ms::engine::MAX_TABLE_FILE_MEM / singleVectorMem; std::vector vectors_128M; BuildVectors(n_max, vectors_128M); vector_ids.clear(); - engine::VectorSourcePtr source_128M = std::make_shared(n_max, vectors_128M.data()); + ms::engine::VectorSourcePtr source_128M = std::make_shared(n_max, vectors_128M.data()); status = mem_table.Add(source_128M, vector_ids); ASSERT_TRUE(status.ok()); @@ -198,7 +200,7 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) { std::vector vectors_1G; BuildVectors(n_1G, vectors_1G); - engine::VectorSourcePtr source_1G = std::make_shared(n_1G, vectors_1G.data()); + ms::engine::VectorSourcePtr source_1G = std::make_shared(n_1G, vectors_1G.data()); vector_ids.clear(); status = mem_table.Add(source_1G, vector_ids); @@ -207,7 +209,7 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) { vector_ids = source_1G->GetVectorIds(); ASSERT_EQ(vector_ids.size(), n_1G); - int expectedTableFileCount = 2 + std::ceil((n_1G - n_100) * singleVectorMem / engine::MAX_TABLE_FILE_MEM); + int expectedTableFileCount = 2 + std::ceil((n_1G - n_100) * singleVectorMem / ms::engine::MAX_TABLE_FILE_MEM); ASSERT_EQ(mem_table.GetTableFileCount(), expectedTableFileCount); status = mem_table.Serialize(); @@ -215,11 +217,11 @@ TEST_F(MemManagerTest, MEM_TABLE_TEST) { } TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; - table_info_get.table_id_ = TABLE_NAME; + ms::engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = GetTableName(); stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); @@ -228,12 +230,12 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) { std::vector xb; BuildVectors(nb, xb); - engine::IDNumbers vector_ids; - for(int64_t i = 0; i < nb; i++) { + ms::engine::IDNumbers vector_ids; + for (int64_t i = 0; i < nb; i++) { vector_ids.push_back(i); } - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); ASSERT_TRUE(stat.ok()); std::this_thread::sleep_for(std::chrono::seconds(3));//ensure raw data write to disk @@ -256,19 +258,19 @@ TEST_F(MemManagerTest2, SERIAL_INSERT_SEARCH_TEST) { int topk = 10, nprobe = 10; for (auto &pair : search_vectors) { auto &search = pair.second; - engine::QueryResults results; - stat = db_->Query(TABLE_NAME, topk, 1, nprobe, search.data(), results); + ms::engine::QueryResults results; + stat = db_->Query(GetTableName(), topk, 1, nprobe, search.data(), results); ASSERT_EQ(results[0][0].first, pair.first); ASSERT_LT(results[0][0].second, 1e-4); } } TEST_F(MemManagerTest2, INSERT_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; - table_info_get.table_id_ = TABLE_NAME; + ms::engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = GetTableName(); stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); @@ -280,8 +282,8 @@ TEST_F(MemManagerTest2, INSERT_TEST) { int64_t nb = 40960; std::vector xb; BuildVectors(nb, xb); - engine::IDNumbers vector_ids; - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ms::engine::IDNumbers vector_ids; + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); ASSERT_TRUE(stat.ok()); } auto end_time = METRICS_NOW_TIME; @@ -290,17 +292,17 @@ TEST_F(MemManagerTest2, INSERT_TEST) { } TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; - table_info_get.table_id_ = TABLE_NAME; + ms::engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = GetTableName(); stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); - engine::IDNumbers vector_ids; - engine::IDNumbers target_ids; + ms::engine::IDNumbers vector_ids; + ms::engine::IDNumbers target_ids; int64_t nb = 40960; std::vector xb; @@ -311,7 +313,7 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { BuildVectors(qb, qxb); std::thread search([&]() { - engine::QueryResults results; + ms::engine::QueryResults results; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -326,8 +328,8 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { prev_count = count; START_TIMER; - stat = db_->Query(TABLE_NAME, k, qb, 10, qxb.data(), results); - ss << "Search " << j << " With Size " << count / engine::meta::M << " M"; + stat = db_->Query(GetTableName(), k, qb, 10, qxb.data(), results); + ss << "Search " << j << " With Size " << count / ms::engine::M << " M"; STOP_TIMER(ss.str()); ASSERT_TRUE(stat.ok()); @@ -349,29 +351,28 @@ TEST_F(MemManagerTest2, CONCURRENT_INSERT_SEARCH_TEST) { for (auto i = 0; i < loop; ++i) { if (i == 0) { - db_->InsertVectors(TABLE_NAME, qb, qxb.data(), target_ids); + db_->InsertVectors(GetTableName(), qb, qxb.data(), target_ids); ASSERT_EQ(target_ids.size(), qb); } else { - db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); } std::this_thread::sleep_for(std::chrono::microseconds(1)); } search.join(); -}; +} TEST_F(MemManagerTest2, VECTOR_IDS_TEST) { - engine::meta::TableSchema table_info = BuildTableSchema(); + ms::engine::meta::TableSchema table_info = BuildTableSchema(); auto stat = db_->CreateTable(table_info); - engine::meta::TableSchema table_info_get; - table_info_get.table_id_ = TABLE_NAME; + ms::engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = GetTableName(); stat = db_->DescribeTable(table_info_get); ASSERT_TRUE(stat.ok()); ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); - engine::IDNumbers vector_ids; - + ms::engine::IDNumbers vector_ids; int64_t nb = 100000; std::vector xb; @@ -382,7 +383,7 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) { vector_ids[i] = i; } - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); ASSERT_EQ(vector_ids[0], 0); ASSERT_TRUE(stat.ok()); @@ -394,7 +395,7 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) { for (auto i = 0; i < nb; i++) { vector_ids[i] = i + nb; } - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); ASSERT_EQ(vector_ids[0], nb); ASSERT_TRUE(stat.ok()); @@ -406,15 +407,15 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) { for (auto i = 0; i < nb; i++) { vector_ids[i] = i + nb / 2; } - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); - ASSERT_EQ(vector_ids[0], nb/2); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); + ASSERT_EQ(vector_ids[0], nb / 2); ASSERT_TRUE(stat.ok()); nb = 65536; //128M xb.clear(); BuildVectors(nb, xb); vector_ids.clear(); - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); ASSERT_TRUE(stat.ok()); nb = 100; @@ -425,8 +426,9 @@ TEST_F(MemManagerTest2, VECTOR_IDS_TEST) { for (auto i = 0; i < nb; i++) { vector_ids[i] = i + nb; } - stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + stat = db_->InsertVectors(GetTableName(), nb, xb.data(), vector_ids); for (auto i = 0; i < nb; i++) { ASSERT_EQ(vector_ids[i], i + nb); } } + diff --git a/cpp/unittest/db/meta_tests.cpp b/cpp/unittest/db/test_meta.cpp similarity index 64% rename from cpp/unittest/db/meta_tests.cpp rename to cpp/unittest/db/test_meta.cpp index 23897408bc4cf0631b3ad065d184b62275ce77f2..d88c087aa47ea9859e32d267912190bc31db4349 100644 --- a/cpp/unittest/db/meta_tests.cpp +++ b/cpp/unittest/db/test_meta.cpp @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -#include "utils.h" +#include "db/utils.h" #include "db/meta/SqliteMetaImpl.h" #include "db/Utils.h" +#include "db/Constants.h" #include "db/meta/MetaConsts.h" #include @@ -25,13 +26,16 @@ #include #include -using namespace zilliz::milvus; -using namespace zilliz::milvus::engine; +namespace { + +namespace ms = milvus; + +} // namespace TEST_F(MetaTest, TABLE_TEST) { auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); ASSERT_TRUE(status.ok()); @@ -49,7 +53,7 @@ TEST_F(MetaTest, TABLE_TEST) { table.table_id_ = table_id; status = impl_->CreateTable(table); - ASSERT_EQ(status.code(), DB_ALREADY_EXIST); + ASSERT_EQ(status.code(), ms::DB_ALREADY_EXIST); table.table_id_ = ""; status = impl_->CreateTable(table); @@ -59,16 +63,16 @@ TEST_F(MetaTest, TABLE_TEST) { TEST_F(MetaTest, TABLE_FILE_TEST) { auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; table.dimension_ = 256; auto status = impl_->CreateTable(table); - meta::TableFileSchema table_file; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; status = impl_->CreateTableFile(table_file); ASSERT_TRUE(status.ok()); - ASSERT_EQ(table_file.file_type_, meta::TableFileSchema::NEW); + ASSERT_EQ(table_file.file_type_, ms::engine::meta::TableFileSchema::NEW); uint64_t cnt = 0; status = impl_->Count(table_id, cnt); @@ -77,30 +81,30 @@ TEST_F(MetaTest, TABLE_FILE_TEST) { auto file_id = table_file.file_id_; - auto new_file_type = meta::TableFileSchema::INDEX; + auto new_file_type = ms::engine::meta::TableFileSchema::INDEX; table_file.file_type_ = new_file_type; status = impl_->UpdateTableFile(table_file); ASSERT_TRUE(status.ok()); ASSERT_EQ(table_file.file_type_, new_file_type); - meta::DatesT dates; - dates.push_back(utils::GetDate()); + ms::engine::meta::DatesT dates; + dates.push_back(ms::engine::utils::GetDate()); status = impl_->DropPartitionsByDates(table_file.table_id_, dates); ASSERT_TRUE(status.ok()); dates.clear(); - for (auto i=2; i < 10; ++i) { - dates.push_back(utils::GetDateWithDelta(-1*i)); + for (auto i = 2; i < 10; ++i) { + dates.push_back(ms::engine::utils::GetDateWithDelta(-1 * i)); } status = impl_->DropPartitionsByDates(table_file.table_id_, dates); ASSERT_TRUE(status.ok()); - table_file.date_ = utils::GetDateWithDelta(-2); + table_file.date_ = ms::engine::utils::GetDateWithDelta(-2); status = impl_->UpdateTableFile(table_file); ASSERT_TRUE(status.ok()); - ASSERT_EQ(table_file.date_, utils::GetDateWithDelta(-2)); - ASSERT_FALSE(table_file.file_type_ == meta::TableFileSchema::TO_DELETE); + ASSERT_EQ(table_file.date_, ms::engine::utils::GetDateWithDelta(-2)); + ASSERT_FALSE(table_file.file_type_ == ms::engine::meta::TableFileSchema::TO_DELETE); dates.clear(); dates.push_back(table_file.date_); @@ -108,7 +112,7 @@ TEST_F(MetaTest, TABLE_FILE_TEST) { ASSERT_TRUE(status.ok()); std::vector ids = {table_file.id_}; - meta::TableFilesSchema files; + ms::engine::meta::TableFilesSchema files; status = impl_->GetTableFiles(table_file.table_id_, ids, files); ASSERT_TRUE(status.ok()); ASSERT_EQ(files.size(), 0UL); @@ -116,33 +120,34 @@ TEST_F(MetaTest, TABLE_FILE_TEST) { TEST_F(MetaTest, ARCHIVE_TEST_DAYS) { srand(time(0)); - DBMetaOptions options; + ms::engine::DBMetaOptions options; options.path_ = "/tmp/milvus_test"; - int days_num = rand() % 100; + unsigned int seed = 1; + int days_num = rand_r(&seed) % 100; std::stringstream ss; ss << "days:" << days_num; - options.archive_conf_ = ArchiveConf("delete", ss.str()); + options.archive_conf_ = ms::engine::ArchiveConf("delete", ss.str()); - meta::SqliteMetaImpl impl(options); + ms::engine::meta::SqliteMetaImpl impl(options); auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl.CreateTable(table); - meta::TableFilesSchema files; - meta::TableFileSchema table_file; + ms::engine::meta::TableFilesSchema files; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; auto cnt = 100; - long ts = utils::GetMicroSecTimeStamp(); + int64_t ts = ms::engine::utils::GetMicroSecTimeStamp(); std::vector days; std::vector ids; - for (auto i=0; i ids; - for (auto i=0; i= 5) { - ASSERT_EQ(file.file_type_, meta::TableFileSchema::NEW); + ASSERT_EQ(file.file_type_, ms::engine::meta::TableFileSchema::NEW); } ++i; } @@ -214,7 +219,7 @@ TEST_F(MetaTest, ARCHIVE_TEST_DISK) { TEST_F(MetaTest, TABLE_FILES_TEST) { auto table_id = "meta_test_group"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); @@ -226,51 +231,51 @@ TEST_F(MetaTest, TABLE_FILES_TEST) { uint64_t to_index_files_cnt = 6; uint64_t index_files_cnt = 7; - meta::TableFileSchema table_file; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW_MERGE; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW_MERGE; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW_INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW_INDEX; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::BACKUP; + table_file.file_type_ = ms::engine::meta::TableFileSchema::BACKUP; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::RAW; + table_file.file_type_ = ms::engine::meta::TableFileSchema::RAW; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::TO_INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::TO_INDEX; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::INDEX; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } @@ -278,36 +283,36 @@ TEST_F(MetaTest, TABLE_FILES_TEST) { uint64_t total_row_count = 0; status = impl_->Count(table_id, total_row_count); ASSERT_TRUE(status.ok()); - ASSERT_EQ(total_row_count, raw_files_cnt+to_index_files_cnt+index_files_cnt); + ASSERT_EQ(total_row_count, raw_files_cnt + to_index_files_cnt + index_files_cnt); - meta::TableFilesSchema files; + ms::engine::meta::TableFilesSchema files; status = impl_->FilesToIndex(files); ASSERT_EQ(files.size(), to_index_files_cnt); - meta::DatePartionedTableFilesSchema dated_files; + ms::engine::meta::DatePartionedTableFilesSchema dated_files; status = impl_->FilesToMerge(table.table_id_, dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), raw_files_cnt); status = impl_->FilesToIndex(files); ASSERT_EQ(files.size(), to_index_files_cnt); - meta::DatesT dates = {table_file.date_}; + ms::engine::meta::DatesT dates = {table_file.date_}; std::vector ids; status = impl_->FilesToSearch(table_id, ids, dates, dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); - status = impl_->FilesToSearch(table_id, ids, meta::DatesT(), dated_files); + status = impl_->FilesToSearch(table_id, ids, ms::engine::meta::DatesT(), dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); - status = impl_->FilesToSearch(table_id, ids, meta::DatesT(), dated_files); + status = impl_->FilesToSearch(table_id, ids, ms::engine::meta::DatesT(), dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); ids.push_back(size_t(9999999999)); status = impl_->FilesToSearch(table_id, ids, dates, dated_files); - ASSERT_EQ(dated_files[table_file.date_].size(),0); + ASSERT_EQ(dated_files[table_file.date_].size(), 0); std::vector file_types; std::vector file_ids; @@ -316,26 +321,26 @@ TEST_F(MetaTest, TABLE_FILES_TEST) { ASSERT_FALSE(status.ok()); file_types = { - meta::TableFileSchema::NEW, - meta::TableFileSchema::NEW_MERGE, - meta::TableFileSchema::NEW_INDEX, - meta::TableFileSchema::TO_INDEX, - meta::TableFileSchema::INDEX, - meta::TableFileSchema::RAW, - meta::TableFileSchema::BACKUP, + ms::engine::meta::TableFileSchema::NEW, + ms::engine::meta::TableFileSchema::NEW_MERGE, + ms::engine::meta::TableFileSchema::NEW_INDEX, + ms::engine::meta::TableFileSchema::TO_INDEX, + ms::engine::meta::TableFileSchema::INDEX, + ms::engine::meta::TableFileSchema::RAW, + ms::engine::meta::TableFileSchema::BACKUP, }; status = impl_->FilesByType(table.table_id_, file_types, file_ids); ASSERT_TRUE(status.ok()); uint64_t total_cnt = new_index_files_cnt + new_merge_files_cnt + - backup_files_cnt + new_files_cnt + raw_files_cnt + - to_index_files_cnt + index_files_cnt; + backup_files_cnt + new_files_cnt + raw_files_cnt + + to_index_files_cnt + index_files_cnt; ASSERT_EQ(file_ids.size(), total_cnt); status = impl_->DeleteTableFiles(table_id); ASSERT_TRUE(status.ok()); status = impl_->CreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW; status = impl_->UpdateTableFile(table_file); status = impl_->CleanUp(); ASSERT_TRUE(status.ok()); @@ -350,11 +355,11 @@ TEST_F(MetaTest, TABLE_FILES_TEST) { TEST_F(MetaTest, INDEX_TEST) { auto table_id = "index_test"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); - TableIndex index; + ms::engine::TableIndex index; index.metric_type_ = 2; index.nlist_ = 1234; index.engine_type_ = 3; @@ -365,12 +370,12 @@ TEST_F(MetaTest, INDEX_TEST) { status = impl_->UpdateTableFlag(table_id, flag); ASSERT_TRUE(status.ok()); - engine::meta::TableSchema table_info; + ms::engine::meta::TableSchema table_info; table_info.table_id_ = table_id; status = impl_->DescribeTable(table_info); ASSERT_EQ(table_info.flag_, flag); - TableIndex index_out; + ms::engine::TableIndex index_out; status = impl_->DescribeTableIndex(table_id, index_out); ASSERT_EQ(index_out.metric_type_, index.metric_type_); ASSERT_EQ(index_out.nlist_, index.nlist_); @@ -385,4 +390,4 @@ TEST_F(MetaTest, INDEX_TEST) { status = impl_->UpdateTableFilesToIndex(table_id); ASSERT_TRUE(status.ok()); -} \ No newline at end of file +} diff --git a/cpp/unittest/db/mysql_meta_test.cpp b/cpp/unittest/db/test_meta_mysql.cpp similarity index 64% rename from cpp/unittest/db/mysql_meta_test.cpp rename to cpp/unittest/db/test_meta_mysql.cpp index 46b030235eeb1144b4b7e44870161343b10cf6ef..a825e8cbdd3fb724f0e56c443fe3a1a33ff980ed 100644 --- a/cpp/unittest/db/mysql_meta_test.cpp +++ b/cpp/unittest/db/test_meta_mysql.cpp @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "utils.h" +#include "db/utils.h" #include "db/meta/MySQLMetaImpl.h" #include "db/Utils.h" #include "db/meta/MetaConsts.h" @@ -27,13 +27,17 @@ #include #include -using namespace zilliz::milvus; -using namespace zilliz::milvus::engine; +namespace { + +namespace ms = milvus; + +} + TEST_F(MySqlMetaTest, TABLE_TEST) { auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); ASSERT_TRUE(status.ok()); @@ -51,7 +55,7 @@ TEST_F(MySqlMetaTest, TABLE_TEST) { table.table_id_ = table_id; status = impl_->CreateTable(table); - ASSERT_EQ(status.code(), DB_ALREADY_EXIST); + ASSERT_EQ(status.code(), ms::DB_ALREADY_EXIST); table.table_id_ = ""; status = impl_->CreateTable(table); @@ -64,20 +68,19 @@ TEST_F(MySqlMetaTest, TABLE_TEST) { TEST_F(MySqlMetaTest, TABLE_FILE_TEST) { auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; table.dimension_ = 256; auto status = impl_->CreateTable(table); - - meta::TableFileSchema table_file; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; status = impl_->CreateTableFile(table_file); ASSERT_TRUE(status.ok()); - ASSERT_EQ(table_file.file_type_, meta::TableFileSchema::NEW); + ASSERT_EQ(table_file.file_type_, ms::engine::meta::TableFileSchema::NEW); - meta::DatesT dates; - dates.push_back(utils::GetDate()); + ms::engine::meta::DatesT dates; + dates.push_back(ms::engine::utils::GetDate()); status = impl_->DropPartitionsByDates(table_file.table_id_, dates); ASSERT_TRUE(status.ok()); @@ -88,7 +91,7 @@ TEST_F(MySqlMetaTest, TABLE_FILE_TEST) { auto file_id = table_file.file_id_; - auto new_file_type = meta::TableFileSchema::INDEX; + auto new_file_type = ms::engine::meta::TableFileSchema::INDEX; table_file.file_type_ = new_file_type; status = impl_->UpdateTableFile(table_file); @@ -96,17 +99,17 @@ TEST_F(MySqlMetaTest, TABLE_FILE_TEST) { ASSERT_EQ(table_file.file_type_, new_file_type); dates.clear(); - for (auto i=2; i < 10; ++i) { - dates.push_back(utils::GetDateWithDelta(-1*i)); + for (auto i = 2; i < 10; ++i) { + dates.push_back(ms::engine::utils::GetDateWithDelta(-1 * i)); } status = impl_->DropPartitionsByDates(table_file.table_id_, dates); ASSERT_TRUE(status.ok()); - table_file.date_ = utils::GetDateWithDelta(-2); + table_file.date_ = ms::engine::utils::GetDateWithDelta(-2); status = impl_->UpdateTableFile(table_file); ASSERT_TRUE(status.ok()); - ASSERT_EQ(table_file.date_, utils::GetDateWithDelta(-2)); - ASSERT_FALSE(table_file.file_type_ == meta::TableFileSchema::TO_DELETE); + ASSERT_EQ(table_file.date_, ms::engine::utils::GetDateWithDelta(-2)); + ASSERT_FALSE(table_file.file_type_ == ms::engine::meta::TableFileSchema::TO_DELETE); dates.clear(); dates.push_back(table_file.date_); @@ -114,41 +117,42 @@ TEST_F(MySqlMetaTest, TABLE_FILE_TEST) { ASSERT_TRUE(status.ok()); std::vector ids = {table_file.id_}; - meta::TableFilesSchema files; + ms::engine::meta::TableFilesSchema files; status = impl_->GetTableFiles(table_file.table_id_, ids, files); ASSERT_EQ(files.size(), 0UL); } TEST_F(MySqlMetaTest, ARCHIVE_TEST_DAYS) { srand(time(0)); - DBMetaOptions options = GetOptions().meta_; + ms::engine::DBMetaOptions options = GetOptions().meta_; - int days_num = rand() % 100; + unsigned int seed = 1; + int days_num = rand_r(&seed) % 100; std::stringstream ss; ss << "days:" << days_num; - options.archive_conf_ = ArchiveConf("delete", ss.str()); - int mode = DBOptions::MODE::SINGLE; - meta::MySQLMetaImpl impl(options, mode); + options.archive_conf_ = ms::engine::ArchiveConf("delete", ss.str()); + int mode = ms::engine::DBOptions::MODE::SINGLE; + ms::engine::meta::MySQLMetaImpl impl(options, mode); auto table_id = "meta_test_table"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl.CreateTable(table); - meta::TableFilesSchema files; - meta::TableFileSchema table_file; + ms::engine::meta::TableFilesSchema files; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; auto cnt = 100; - long ts = utils::GetMicroSecTimeStamp(); + int64_t ts = ms::engine::utils::GetMicroSecTimeStamp(); std::vector days; std::vector ids; - for (auto i=0; i file_types = { - (int) meta::TableFileSchema::NEW, + (int) ms::engine::meta::TableFileSchema::NEW, }; std::vector file_ids; status = impl.FilesByType(table_id, file_types, file_ids); @@ -184,32 +188,32 @@ TEST_F(MySqlMetaTest, ARCHIVE_TEST_DAYS) { } TEST_F(MySqlMetaTest, ARCHIVE_TEST_DISK) { - DBMetaOptions options = GetOptions().meta_; + ms::engine::DBMetaOptions options = GetOptions().meta_; - options.archive_conf_ = ArchiveConf("delete", "disk:11"); - int mode = DBOptions::MODE::SINGLE; - auto impl = meta::MySQLMetaImpl(options, mode); + options.archive_conf_ = ms::engine::ArchiveConf("delete", "disk:11"); + int mode = ms::engine::DBOptions::MODE::SINGLE; + auto impl = ms::engine::meta::MySQLMetaImpl(options, mode); auto table_id = "meta_test_group"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl.CreateTable(table); - meta::TableSchema table_schema; + ms::engine::meta::TableSchema table_schema; table_schema.table_id_ = ""; status = impl.CreateTable(table_schema); - meta::TableFilesSchema files; - meta::TableFileSchema table_file; + ms::engine::meta::TableFilesSchema files; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; auto cnt = 10; auto each_size = 2UL; std::vector ids; - for (auto i=0; i= 5) { - ASSERT_EQ(file.file_type_, meta::TableFileSchema::NEW); + ASSERT_EQ(file.file_type_, ms::engine::meta::TableFileSchema::NEW); } ++i; } @@ -236,7 +240,7 @@ TEST_F(MySqlMetaTest, ARCHIVE_TEST_DISK) { TEST_F(MySqlMetaTest, TABLE_FILES_TEST) { auto table_id = "meta_test_group"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); @@ -248,51 +252,51 @@ TEST_F(MySqlMetaTest, TABLE_FILES_TEST) { uint64_t to_index_files_cnt = 6; uint64_t index_files_cnt = 7; - meta::TableFileSchema table_file; + ms::engine::meta::TableFileSchema table_file; table_file.table_id_ = table.table_id_; - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW_MERGE; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW_MERGE; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW_INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW_INDEX; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::BACKUP; + table_file.file_type_ = ms::engine::meta::TableFileSchema::BACKUP; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::NEW; + table_file.file_type_ = ms::engine::meta::TableFileSchema::NEW; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::RAW; + table_file.file_type_ = ms::engine::meta::TableFileSchema::RAW; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::TO_INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::TO_INDEX; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } - for (auto i=0; iCreateTableFile(table_file); - table_file.file_type_ = meta::TableFileSchema::INDEX; + table_file.file_type_ = ms::engine::meta::TableFileSchema::INDEX; table_file.row_count_ = 1; status = impl_->UpdateTableFile(table_file); } @@ -300,36 +304,36 @@ TEST_F(MySqlMetaTest, TABLE_FILES_TEST) { uint64_t total_row_count = 0; status = impl_->Count(table_id, total_row_count); ASSERT_TRUE(status.ok()); - ASSERT_EQ(total_row_count, raw_files_cnt+to_index_files_cnt+index_files_cnt); + ASSERT_EQ(total_row_count, raw_files_cnt + to_index_files_cnt + index_files_cnt); - meta::TableFilesSchema files; + ms::engine::meta::TableFilesSchema files; status = impl_->FilesToIndex(files); ASSERT_EQ(files.size(), to_index_files_cnt); - meta::DatePartionedTableFilesSchema dated_files; + ms::engine::meta::DatePartionedTableFilesSchema dated_files; status = impl_->FilesToMerge(table.table_id_, dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), raw_files_cnt); status = impl_->FilesToIndex(files); ASSERT_EQ(files.size(), to_index_files_cnt); - meta::DatesT dates = {table_file.date_}; + ms::engine::meta::DatesT dates = {table_file.date_}; std::vector ids; status = impl_->FilesToSearch(table_id, ids, dates, dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); - status = impl_->FilesToSearch(table_id, ids, meta::DatesT(), dated_files); + status = impl_->FilesToSearch(table_id, ids, ms::engine::meta::DatesT(), dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); - status = impl_->FilesToSearch(table_id, ids, meta::DatesT(), dated_files); + status = impl_->FilesToSearch(table_id, ids, ms::engine::meta::DatesT(), dated_files); ASSERT_EQ(dated_files[table_file.date_].size(), - to_index_files_cnt+raw_files_cnt+index_files_cnt); + to_index_files_cnt + raw_files_cnt + index_files_cnt); ids.push_back(size_t(9999999999)); status = impl_->FilesToSearch(table_id, ids, dates, dated_files); - ASSERT_EQ(dated_files[table_file.date_].size(),0); + ASSERT_EQ(dated_files[table_file.date_].size(), 0); std::vector file_types; std::vector file_ids; @@ -338,19 +342,19 @@ TEST_F(MySqlMetaTest, TABLE_FILES_TEST) { ASSERT_FALSE(status.ok()); file_types = { - meta::TableFileSchema::NEW, - meta::TableFileSchema::NEW_MERGE, - meta::TableFileSchema::NEW_INDEX, - meta::TableFileSchema::TO_INDEX, - meta::TableFileSchema::INDEX, - meta::TableFileSchema::RAW, - meta::TableFileSchema::BACKUP, + ms::engine::meta::TableFileSchema::NEW, + ms::engine::meta::TableFileSchema::NEW_MERGE, + ms::engine::meta::TableFileSchema::NEW_INDEX, + ms::engine::meta::TableFileSchema::TO_INDEX, + ms::engine::meta::TableFileSchema::INDEX, + ms::engine::meta::TableFileSchema::RAW, + ms::engine::meta::TableFileSchema::BACKUP, }; status = impl_->FilesByType(table.table_id_, file_types, file_ids); ASSERT_TRUE(status.ok()); uint64_t total_cnt = new_index_files_cnt + new_merge_files_cnt + - backup_files_cnt + new_files_cnt + raw_files_cnt + - to_index_files_cnt + index_files_cnt; + backup_files_cnt + new_files_cnt + raw_files_cnt + + to_index_files_cnt + index_files_cnt; ASSERT_EQ(file_ids.size(), total_cnt); status = impl_->DeleteTableFiles(table_id); @@ -366,11 +370,11 @@ TEST_F(MySqlMetaTest, TABLE_FILES_TEST) { TEST_F(MySqlMetaTest, INDEX_TEST) { auto table_id = "index_test"; - meta::TableSchema table; + ms::engine::meta::TableSchema table; table.table_id_ = table_id; auto status = impl_->CreateTable(table); - TableIndex index; + ms::engine::TableIndex index; index.metric_type_ = 2; index.nlist_ = 1234; index.engine_type_ = 3; @@ -381,12 +385,12 @@ TEST_F(MySqlMetaTest, INDEX_TEST) { status = impl_->UpdateTableFlag(table_id, flag); ASSERT_TRUE(status.ok()); - engine::meta::TableSchema table_info; + ms::engine::meta::TableSchema table_info; table_info.table_id_ = table_id; status = impl_->DescribeTable(table_info); ASSERT_EQ(table_info.flag_, flag); - TableIndex index_out; + ms::engine::TableIndex index_out; status = impl_->DescribeTableIndex(table_id, index_out); ASSERT_EQ(index_out.metric_type_, index.metric_type_); ASSERT_EQ(index_out.nlist_, index.nlist_); diff --git a/cpp/unittest/db/misc_test.cpp b/cpp/unittest/db/test_misc.cpp similarity index 69% rename from cpp/unittest/db/misc_test.cpp rename to cpp/unittest/db/test_misc.cpp index 762e9944ee8cc447287eb9807307cff125a27d05..18cca45b4b913b1b5e4364cbe3aceeb35e4e736f 100644 --- a/cpp/unittest/db/misc_test.cpp +++ b/cpp/unittest/db/test_misc.cpp @@ -27,43 +27,47 @@ #include #include -using namespace zilliz::milvus; +namespace { + +namespace ms = milvus; + +} // namespace TEST(DBMiscTest, EXCEPTION_TEST) { - Exception ex1(100, "error"); + ms::Exception ex1(100, "error"); std::string what = ex1.what(); ASSERT_EQ(what, "error"); ASSERT_EQ(ex1.code(), 100); - InvalidArgumentException ex2; - ASSERT_EQ(ex2.code(), SERVER_INVALID_ARGUMENT); + ms::InvalidArgumentException ex2; + ASSERT_EQ(ex2.code(), ms::SERVER_INVALID_ARGUMENT); } TEST(DBMiscTest, OPTIONS_TEST) { try { - engine::ArchiveConf archive("$$##"); - } catch (std::exception& ex) { + ms::engine::ArchiveConf archive("$$##"); + } catch (std::exception &ex) { ASSERT_TRUE(true); } { - engine::ArchiveConf archive("delete", "no"); + ms::engine::ArchiveConf archive("delete", "no"); ASSERT_TRUE(archive.GetCriterias().empty()); } { - engine::ArchiveConf archive("delete", "1:2"); + ms::engine::ArchiveConf archive("delete", "1:2"); ASSERT_TRUE(archive.GetCriterias().empty()); } { - engine::ArchiveConf archive("delete", "1:2:3"); + ms::engine::ArchiveConf archive("delete", "1:2:3"); ASSERT_TRUE(archive.GetCriterias().empty()); } { - engine::ArchiveConf archive("delete"); - engine::ArchiveConf::CriteriaT criterial = { + ms::engine::ArchiveConf archive("delete"); + ms::engine::ArchiveConf::CriteriaT criterial = { {"disk", 1024}, {"days", 100} }; @@ -76,29 +80,29 @@ TEST(DBMiscTest, OPTIONS_TEST) { } TEST(DBMiscTest, META_TEST) { - engine::DBMetaOptions options; + ms::engine::DBMetaOptions options; options.path_ = "/tmp/milvus_test"; - engine::meta::SqliteMetaImpl impl(options); + ms::engine::meta::SqliteMetaImpl impl(options); time_t tt; - time( &tt ); + time(&tt); int delta = 10; - engine::meta::DateT dt = engine::utils::GetDate(tt, delta); + ms::engine::meta::DateT dt = ms::engine::utils::GetDate(tt, delta); ASSERT_GT(dt, 0); } TEST(DBMiscTest, UTILS_TEST) { - engine::DBMetaOptions options; + ms::engine::DBMetaOptions options; options.path_ = "/tmp/milvus_test/main"; options.slave_paths_.push_back("/tmp/milvus_test/slave_1"); options.slave_paths_.push_back("/tmp/milvus_test/slave_2"); const std::string TABLE_NAME = "test_tbl"; - auto status = engine::utils::CreateTablePath(options, TABLE_NAME); + auto status = ms::engine::utils::CreateTablePath(options, TABLE_NAME); ASSERT_TRUE(status.ok()); ASSERT_TRUE(boost::filesystem::exists(options.path_)); - for(auto& path : options.slave_paths_) { - ASSERT_TRUE(boost::filesystem::exists(path)); + for (auto &path : options.slave_paths_) { + ASSERT_TRUE(boost::filesystem::exists(path)); } // options.slave_paths.push_back("/"); @@ -109,18 +113,18 @@ TEST(DBMiscTest, UTILS_TEST) { // status = engine::utils::CreateTablePath(options, TABLE_NAME); // ASSERT_FALSE(status.ok()); - engine::meta::TableFileSchema file; + ms::engine::meta::TableFileSchema file; file.id_ = 50; file.table_id_ = TABLE_NAME; file.file_type_ = 3; file.date_ = 155000; - status = engine::utils::GetTableFilePath(options, file); + status = ms::engine::utils::GetTableFilePath(options, file); ASSERT_FALSE(status.ok()); ASSERT_TRUE(file.location_.empty()); - status = engine::utils::DeleteTablePath(options, TABLE_NAME); + status = ms::engine::utils::DeleteTablePath(options, TABLE_NAME); ASSERT_TRUE(status.ok()); - status = engine::utils::DeleteTableFilePath(options, file); + status = ms::engine::utils::DeleteTableFilePath(options, file); ASSERT_TRUE(status.ok()); -} \ No newline at end of file +} diff --git a/cpp/unittest/db/test_search.cpp b/cpp/unittest/db/test_search.cpp new file mode 100644 index 0000000000000000000000000000000000000000..348463357eb01d48e22f05b0f40a3c3e0acf6bd5 --- /dev/null +++ b/cpp/unittest/db/test_search.cpp @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "scheduler/task/SearchTask.h" +#include "utils/TimeRecorder.h" + +namespace { + +namespace ms = milvus::scheduler; + +void +BuildResult(uint64_t nq, + uint64_t topk, + bool ascending, + std::vector &output_ids, + std::vector &output_distence) { + output_ids.clear(); + output_ids.resize(nq * topk); + output_distence.clear(); + output_distence.resize(nq * topk); + + for (uint64_t i = 0; i < nq; i++) { + for (uint64_t j = 0; j < topk; j++) { + output_ids[i * topk + j] = (int64_t) (drand48() * 100000); + output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); + } + } +} + +void CheckTopkResult(const std::vector &input_ids_1, + const std::vector &input_distance_1, + const std::vector &input_ids_2, + const std::vector &input_distance_2, + uint64_t nq, + uint64_t topk, + bool ascending, + const ms::ResultSet& result) { + ASSERT_EQ(result.size(), nq); + ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); + ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); + + uint64_t input_k1 = input_ids_1.size() / nq; + uint64_t input_k2 = input_ids_2.size() / nq; + + for (int64_t i = 0; i < nq; i++) { + std::vector src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1); + src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2); + if (ascending) { + std::sort(src_vec.begin(), src_vec.end()); + } else { + std::sort(src_vec.begin(), src_vec.end(), std::greater()); + } + + uint64_t n = std::min(topk, input_k1+input_k2); + for (uint64_t j = 0; j < n; j++) { + if (src_vec[j] != result[i][j].second) { + std::cout << src_vec[j] << " " << result[i][j].second << std::endl; + } + ASSERT_TRUE(src_vec[j] == result[i][j].second); + } + } +} + +} // namespace + +TEST(DBSearchTest, TOPK_TEST) { + uint64_t NQ = 15; + uint64_t TOP_K = 64; + bool ascending; + std::vector ids1, ids2; + std::vector dist1, dist2; + ms::ResultSet result; + milvus::Status status; + + /* test1, id1/dist1 valid, id2/dist2 empty */ + ascending = true; + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + +///////////////////////////////////////////////////////////////////////////////////////// + ascending = false; + ids1.clear(); + dist1.clear(); + ids2.clear(); + dist2.clear(); + result.clear(); + + /* test1, id1/dist1 valid, id2/dist2 empty */ + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = ms::XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = ms::XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); +} + +TEST(DBSearchTest, REDUCE_PERF_TEST) { + int32_t nq = 100; + int32_t top_k = 1000; + int32_t index_file_num = 478; /* sift1B dataset, index files num */ + bool ascending = true; + std::vector input_ids; + std::vector input_distance; + ms::ResultSet final_result; + milvus::Status status; + + double span, reduce_cost = 0.0; + milvus::TimeRecorder rc(""); + + for (int32_t i = 0; i < index_file_num; i++) { + BuildResult(nq, top_k, ascending, input_ids, input_distance); + + rc.RecordSection("do search for context: " + std::to_string(i)); + + // pick up topk result + status = ms::XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(final_result.size(), nq); + + span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); + reduce_cost += span; + } + std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl; +} diff --git a/cpp/unittest/db/utils.cpp b/cpp/unittest/db/utils.cpp index 08e3c8f0902d52eec35302c48c7b70e7b8961d6f..c5874be694b46712a61bb787849dd64e2f73b2cd 100644 --- a/cpp/unittest/db/utils.cpp +++ b/cpp/unittest/db/utils.cpp @@ -18,9 +18,11 @@ #include #include +#include +#include #include -#include "utils.h" +#include "db/utils.h" #include "cache/GpuCacheMgr.h" #include "cache/CpuCacheMgr.h" #include "db/DBFactory.h" @@ -29,83 +31,99 @@ INITIALIZE_EASYLOGGINGPP -using namespace zilliz::milvus; +namespace { -static std::string uri; +namespace ms = milvus; class DBTestEnvironment : public ::testing::Environment { -public: - -// explicit DBTestEnvironment(std::string uri) : uri_(uri) {} + public: + explicit DBTestEnvironment(const std::string& uri) + : uri_(uri) { + } - static std::string getURI() { - return uri; + std::string getURI() const { + return uri_; } void SetUp() override { getURI(); } + private: + std::string uri_; }; +DBTestEnvironment* test_env = nullptr; + +} // namespace + + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void BaseTest::InitLog() { +void +BaseTest::InitLog() { el::Configurations defaultConf; defaultConf.setToDefault(); defaultConf.set(el::Level::Debug, - el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)"); + el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)"); el::Loggers::reconfigureLogger("default", defaultConf); } -void BaseTest::SetUp() { +void +BaseTest::SetUp() { InitLog(); - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0, 1024*1024*200, 1024*1024*300, 2); + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0, 1024 * 1024 * 200, 1024 * 1024 * 300, 2); } -void BaseTest::TearDown() { - zilliz::milvus::cache::CpuCacheMgr::GetInstance()->ClearCache(); - zilliz::milvus::cache::GpuCacheMgr::GetInstance(0)->ClearCache(); - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +void +BaseTest::TearDown() { + milvus::cache::CpuCacheMgr::GetInstance()->ClearCache(); + milvus::cache::GpuCacheMgr::GetInstance(0)->ClearCache(); + knowhere::FaissGpuResourceMgr::GetInstance().Free(); } -engine::DBOptions BaseTest::GetOptions() { - auto options = engine::DBFactory::BuildOption(); +ms::engine::DBOptions +BaseTest::GetOptions() { + auto options = ms::engine::DBFactory::BuildOption(); options.meta_.path_ = "/tmp/milvus_test"; options.meta_.backend_uri_ = "sqlite://:@:/"; return options; } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void DBTest::SetUp() { +void +DBTest::SetUp() { BaseTest::SetUp(); - auto res_mgr = scheduler::ResMgrInst::GetInstance(); + auto res_mgr = ms::scheduler::ResMgrInst::GetInstance(); res_mgr->Clear(); - res_mgr->Add(scheduler::ResourceFactory::Create("disk", "DISK", 0, true, false)); - res_mgr->Add(scheduler::ResourceFactory::Create("cpu", "CPU", 0, true, false)); - res_mgr->Add(scheduler::ResourceFactory::Create("gtx1660", "GPU", 0, true, true)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("disk", "DISK", 0, true, false)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("cpu", "CPU", 0, true, false)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("gtx1660", "GPU", 0, true, true)); - auto default_conn = scheduler::Connection("IO", 500.0); - auto PCIE = scheduler::Connection("IO", 11000.0); + auto default_conn = ms::scheduler::Connection("IO", 500.0); + auto PCIE = ms::scheduler::Connection("IO", 11000.0); res_mgr->Connect("disk", "cpu", default_conn); res_mgr->Connect("cpu", "gtx1660", PCIE); res_mgr->Start(); - scheduler::SchedInst::GetInstance()->Start(); + ms::scheduler::SchedInst::GetInstance()->Start(); - scheduler::JobMgrInst::GetInstance()->Start(); + ms::scheduler::JobMgrInst::GetInstance()->Start(); auto options = GetOptions(); - db_ = engine::DBFactory::Build(options); + db_ = ms::engine::DBFactory::Build(options); } -void DBTest::TearDown() { +void +DBTest::TearDown() { db_->Stop(); db_->DropAll(); - scheduler::JobMgrInst::GetInstance()->Stop(); - scheduler::SchedInst::GetInstance()->Stop(); - scheduler::ResMgrInst::GetInstance()->Stop(); + ms::scheduler::JobMgrInst::GetInstance()->Stop(); + ms::scheduler::SchedInst::GetInstance()->Stop(); + ms::scheduler::ResMgrInst::GetInstance()->Stop(); + ms::scheduler::ResMgrInst::GetInstance()->Clear(); BaseTest::TearDown(); @@ -114,23 +132,26 @@ void DBTest::TearDown() { } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -engine::DBOptions DBTest2::GetOptions() { - auto options = engine::DBFactory::BuildOption(); +ms::engine::DBOptions +DBTest2::GetOptions() { + auto options = ms::engine::DBFactory::BuildOption(); options.meta_.path_ = "/tmp/milvus_test"; - options.meta_.archive_conf_ = engine::ArchiveConf("delete", "disk:1"); + options.meta_.archive_conf_ = ms::engine::ArchiveConf("delete", "disk:1"); options.meta_.backend_uri_ = "sqlite://:@:/"; return options; } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void MetaTest::SetUp() { +void +MetaTest::SetUp() { BaseTest::SetUp(); auto options = GetOptions(); - impl_ = std::make_shared(options.meta_); + impl_ = std::make_shared(options.meta_); } -void MetaTest::TearDown() { +void +MetaTest::TearDown() { impl_->DropAll(); BaseTest::TearDown(); @@ -140,27 +161,26 @@ void MetaTest::TearDown() { } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -engine::DBOptions MySqlDBTest::GetOptions() { - auto options = engine::DBFactory::BuildOption(); +ms::engine::DBOptions +MySqlDBTest::GetOptions() { + auto options = ms::engine::DBFactory::BuildOption(); options.meta_.path_ = "/tmp/milvus_test"; - options.meta_.backend_uri_ = DBTestEnvironment::getURI(); - - if(options.meta_.backend_uri_.empty()) { - options.meta_.backend_uri_ = "mysql://root:Fantast1c@192.168.1.194:3306/"; - } + options.meta_.backend_uri_ = test_env->getURI(); return options; } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void MySqlMetaTest::SetUp() { +void +MySqlMetaTest::SetUp() { BaseTest::SetUp(); auto options = GetOptions(); - impl_ = std::make_shared(options.meta_, options.mode_); + impl_ = std::make_shared(options.meta_, options.mode_); } -void MySqlMetaTest::TearDown() { +void +MySqlMetaTest::TearDown() { impl_->DropAll(); BaseTest::TearDown(); @@ -169,29 +189,26 @@ void MySqlMetaTest::TearDown() { boost::filesystem::remove_all(options.meta_.path_); } -engine::DBOptions MySqlMetaTest::GetOptions() { - auto options = engine::DBFactory::BuildOption(); +ms::engine::DBOptions +MySqlMetaTest::GetOptions() { + auto options = ms::engine::DBFactory::BuildOption(); options.meta_.path_ = "/tmp/milvus_test"; - options.meta_.backend_uri_ = DBTestEnvironment::getURI(); - - if(options.meta_.backend_uri_.empty()) { - options.meta_.backend_uri_ = "mysql://root:Fantast1c@192.168.1.194:3306/"; - } + options.meta_.backend_uri_ = test_env->getURI(); return options; } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -int main(int argc, char **argv) { +int +main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); + + std::string uri; if (argc > 1) { uri = argv[1]; } -// if(uri.empty()) { -// uri = "mysql://root:Fantast1c@192.168.1.194:3306/"; -// } -// std::cout << uri << std::endl; - ::testing::AddGlobalTestEnvironment(new DBTestEnvironment); + test_env = new DBTestEnvironment(uri); + ::testing::AddGlobalTestEnvironment(test_env); return RUN_ALL_TESTS(); } diff --git a/cpp/unittest/db/utils.h b/cpp/unittest/db/utils.h index 7a1320ef6b4e57227476a01517d179b4c7d8ae9e..8da160dc92fc60a8ae93b8c0f60d46b8ae6606e6 100644 --- a/cpp/unittest/db/utils.h +++ b/cpp/unittest/db/utils.h @@ -20,7 +20,7 @@ #include #include -//#include +#include #include "db/DB.h" #include "db/meta/SqliteMetaImpl.h" @@ -28,7 +28,6 @@ #include "scheduler/SchedInst.h" #include "scheduler/ResourceFactory.h" - #define TIMING #ifdef TIMING @@ -36,8 +35,7 @@ #define START_TIMER start = std::chrono::high_resolution_clock::now(); #define STOP_TIMER(name) LOG(DEBUG) << "RUNTIME of " << name << ": " << \ std::chrono::duration_cast( \ - std::chrono::high_resolution_clock::now()-start \ - ).count() << " ms "; + std::chrono::high_resolution_clock::now()-start).count() << " ms "; #else #define INIT_TIMER #define START_TIMER @@ -45,27 +43,27 @@ #endif class BaseTest : public ::testing::Test { -protected: + protected: void InitLog(); - virtual void SetUp() override; - virtual void TearDown() override; - virtual zilliz::milvus::engine::DBOptions GetOptions(); + void SetUp() override; + void TearDown() override; + virtual milvus::engine::DBOptions GetOptions(); }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DBTest : public BaseTest { protected: - zilliz::milvus::engine::DBPtr db_; + milvus::engine::DBPtr db_; - virtual void SetUp() override; - virtual void TearDown() override; + void SetUp() override; + void TearDown() override; }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class DBTest2 : public DBTest { protected: - virtual zilliz::milvus::engine::DBOptions GetOptions() override; + milvus::engine::DBOptions GetOptions() override; }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -75,29 +73,28 @@ class EngineTest : public DBTest { ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class MetaTest : public BaseTest { protected: - std::shared_ptr impl_; + std::shared_ptr impl_; - virtual void SetUp() override; - virtual void TearDown() override; + void SetUp() override; + void TearDown() override; }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class MySqlDBTest : public DBTest { -protected: - zilliz::milvus::engine::DBOptions GetOptions(); + protected: + milvus::engine::DBOptions GetOptions() override; }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class MySqlMetaTest : public BaseTest { protected: - std::shared_ptr impl_; + std::shared_ptr impl_; - virtual void SetUp() override; - virtual void TearDown() override; - zilliz::milvus::engine::DBOptions GetOptions(); + void SetUp() override; + void TearDown() override; + milvus::engine::DBOptions GetOptions() override; }; - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// class MemManagerTest : public MetaTest { }; diff --git a/cpp/unittest/main.cpp b/cpp/unittest/main.cpp index 7a28bbf45ffafc12393dc6ec3d3d00c295fad779..d17cf9da581672991eb421b6e6f2ddf1cfa7462c 100644 --- a/cpp/unittest/main.cpp +++ b/cpp/unittest/main.cpp @@ -19,13 +19,11 @@ #include #include "utils/easylogging++.h" -#include "utils/CommonUtil.h" INITIALIZE_EASYLOGGINGPP -using namespace zilliz::milvus; - -int main(int argc, char **argv) { +int +main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/cpp/unittest/metrics/CMakeLists.txt b/cpp/unittest/metrics/CMakeLists.txt index 04eab826bc3f3451516cba92965f4219fb571011..eba8146baa7b0a5b225aed8f16976830b29e0c60 100644 --- a/cpp/unittest/metrics/CMakeLists.txt +++ b/cpp/unittest/metrics/CMakeLists.txt @@ -19,16 +19,13 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} test_files) -include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") -link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") - -set(metrics_test_files +add_executable(test_metrics ${common_files} ${test_files} ) -add_executable(metrics_test ${metrics_test_files}) - -target_link_libraries(metrics_test knowhere ${unittest_libs}) +target_link_libraries(test_metrics + knowhere + ${unittest_libs}) -install(TARGETS metrics_test DESTINATION unittest) \ No newline at end of file +install(TARGETS test_metrics DESTINATION unittest) \ No newline at end of file diff --git a/cpp/unittest/metrics/metricbase_test.cpp b/cpp/unittest/metrics/test_metricbase.cpp similarity index 90% rename from cpp/unittest/metrics/metricbase_test.cpp rename to cpp/unittest/metrics/test_metricbase.cpp index 31766403be1227b08cc8e5081721ac8610c013f0..4edb3b81515eeac45181a20780af9f1c3da4e16a 100644 --- a/cpp/unittest/metrics/metricbase_test.cpp +++ b/cpp/unittest/metrics/test_metricbase.cpp @@ -21,12 +21,16 @@ #include #include -using namespace zilliz::milvus; +namespace { -TEST(MetricbaseTest, METRICBASE_TEST){ - server::MetricsBase instance = server::MetricsBase::GetInstance(); +namespace ms = milvus; + +} // namespace + +TEST(MetricbaseTest, METRICBASE_TEST) { + ms::server::MetricsBase instance = ms::server::MetricsBase::GetInstance(); instance.Init(); - server::SystemInfo::GetInstance().Init(); + ms::server::SystemInfo::GetInstance().Init(); instance.AddVectorsSuccessTotalIncrement(); instance.AddVectorsFailTotalIncrement(); instance.AddVectorsDurationHistogramOberve(1.0); @@ -60,10 +64,10 @@ TEST(MetricbaseTest, METRICBASE_TEST){ instance.QueryResponsePerSecondGaugeSet(1.0); instance.GPUPercentGaugeSet(); instance.GPUMemoryUsageGaugeSet(); - instance.AddVectorsPerSecondGaugeSet(1,1,1); + instance.AddVectorsPerSecondGaugeSet(1, 1, 1); instance.QueryIndexTypePerSecondSet("IVF", 1.0); instance.ConnectionGaugeIncrement(); instance.ConnectionGaugeDecrement(); instance.KeepingAliveCounterIncrement(); instance.OctetsSet(); -} \ No newline at end of file +} diff --git a/cpp/unittest/metrics/metrics_test.cpp b/cpp/unittest/metrics/test_metrics.cpp similarity index 51% rename from cpp/unittest/metrics/metrics_test.cpp rename to cpp/unittest/metrics/test_metrics.cpp index 14c1ead3f33715d103aa2cea555acf0943d8747d..ba0b51af00b6ef950af36087b3582bd2a830bcae 100644 --- a/cpp/unittest/metrics/metrics_test.cpp +++ b/cpp/unittest/metrics/test_metrics.cpp @@ -15,75 +15,73 @@ // specific language governing permissions and limitations // under the License. -#include #include #include #include #include #include #include -//#include "prometheus/registry.h" -//#include "prometheus/exposer.h" -#include -#include +#include "cache/CpuCacheMgr.h" +#include "server/Config.h" #include "metrics/Metrics.h" -#include "utils.h" +#include "metrics/utils.h" #include "db/DB.h" #include "db/meta/SqliteMetaImpl.h" +namespace { -using namespace zilliz::milvus; +namespace ms = milvus; +} // namespace TEST_F(MetricTest, METRIC_TEST) { - server::Config::GetInstance().SetMetricConfigCollector("zabbix"); - server::Metrics::GetInstance(); - server::Config::GetInstance().SetMetricConfigCollector("prometheus"); - server::Metrics::GetInstance(); + ms::server::Config::GetInstance().SetMetricConfigCollector("zabbix"); + ms::server::Metrics::GetInstance(); + ms::server::Config::GetInstance().SetMetricConfigCollector("prometheus"); + ms::server::Metrics::GetInstance(); - server::SystemInfo::GetInstance().Init(); + ms::server::SystemInfo::GetInstance().Init(); // server::Metrics::GetInstance().Init(); // server::Metrics::GetInstance().exposer_ptr()->RegisterCollectable(server::Metrics::GetInstance().registry_ptr()); - server::Metrics::GetInstance().Init(); + ms::server::Metrics::GetInstance().Init(); // server::PrometheusMetrics::GetInstance().exposer_ptr()->RegisterCollectable(server::PrometheusMetrics::GetInstance().registry_ptr()); - zilliz::milvus::cache::CpuCacheMgr::GetInstance()->SetCapacity(1UL*1024*1024*1024); - std::cout<CacheCapacity()<SetCapacity(1UL * 1024 * 1024 * 1024); + std::cout << milvus::cache::CpuCacheMgr::GetInstance()->CacheCapacity() << std::endl; - static const char* group_name = "test_group"; + static const char *group_name = "test_group"; static const int group_dim = 256; - engine::meta::TableSchema group_info; + ms::engine::meta::TableSchema group_info; group_info.dimension_ = group_dim; group_info.table_id_ = group_name; auto stat = db_->CreateTable(group_info); - engine::meta::TableSchema group_info_get; + ms::engine::meta::TableSchema group_info_get; group_info_get.table_id_ = group_name; stat = db_->DescribeTable(group_info_get); - - engine::IDNumbers vector_ids; - engine::IDNumbers target_ids; + ms::engine::IDNumbers vector_ids; + ms::engine::IDNumbers target_ids; int d = 256; int nb = 50; float *xb = new float[d * nb]; - for(int i = 0; i < nb; i++) { - for(int j = 0; j < d; j++) xb[d * i + j] = drand48(); + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); xb[d * i] += i / 2000.; } int qb = 5; float *qxb = new float[d * qb]; - for(int i = 0; i < qb; i++) { - for(int j = 0; j < d; j++) qxb[d * i + j] = drand48(); + for (int i = 0; i < qb; i++) { + for (int j = 0; j < d; j++) qxb[d * i + j] = drand48(); qxb[d * i] += i / 2000.; } std::thread search([&]() { - engine::QueryResults results; + ms::engine::QueryResults results; int k = 10; std::this_thread::sleep_for(std::chrono::seconds(2)); @@ -92,23 +90,23 @@ TEST_F(MetricTest, METRIC_TEST) { uint64_t count = 0; uint64_t prev_count = 0; - for (auto j=0; j<10; ++j) { + for (auto j = 0; j < 10; ++j) { ss.str(""); db_->Size(count); prev_count = count; START_TIMER; // stat = db_->Query(group_name, k, qb, qxb, results); - ss << "Search " << j << " With Size " << (float)(count*group_dim*sizeof(float))/(1024*1024) << " M"; + ss << "Search " << j << " With Size " << (float) (count * group_dim * sizeof(float)) / (1024 * 1024) + << " M"; - for (auto k=0; k= prev_count); std::this_thread::sleep_for(std::chrono::seconds(1)); @@ -117,8 +115,8 @@ TEST_F(MetricTest, METRIC_TEST) { int loop = 10000; - for (auto i=0; iInsertVectors(group_name, qb, qxb, target_ids); ASSERT_EQ(target_ids.size(), qb); } else { @@ -129,37 +127,37 @@ TEST_F(MetricTest, METRIC_TEST) { search.join(); - delete [] xb; - delete [] qxb; -}; + delete[] xb; + delete[] qxb; +} -TEST_F(MetricTest, COLLECTOR_METRICS_TEST){ - auto status = Status::OK(); - server::CollectInsertMetrics insert_metrics0(0, status); - status = Status(DB_ERROR, "error"); - server::CollectInsertMetrics insert_metrics1(0, status); +TEST_F(MetricTest, COLLECTOR_METRICS_TEST) { + auto status = ms::Status::OK(); + ms::server::CollectInsertMetrics insert_metrics0(0, status); + status = ms::Status(ms::DB_ERROR, "error"); + ms::server::CollectInsertMetrics insert_metrics1(0, status); - server::CollectQueryMetrics query_metrics(10); + ms::server::CollectQueryMetrics query_metrics(10); - server::CollectMergeFilesMetrics merge_metrics(); + ms::server::CollectMergeFilesMetrics merge_metrics(); - server::CollectBuildIndexMetrics build_index_metrics(); + ms::server::CollectBuildIndexMetrics build_index_metrics(); - server::CollectExecutionEngineMetrics execution_metrics(10); + ms::server::CollectExecutionEngineMetrics execution_metrics(10); - server::CollectSerializeMetrics serialize_metrics(10); + ms::server::CollectSerializeMetrics serialize_metrics(10); - server::CollectAddMetrics add_metrics(10, 128); + ms::server::CollectAddMetrics add_metrics(10, 128); - server::CollectDurationMetrics duration_metrics_raw(engine::meta::TableFileSchema::RAW); - server::CollectDurationMetrics duration_metrics_index(engine::meta::TableFileSchema::TO_INDEX); - server::CollectDurationMetrics duration_metrics_delete(engine::meta::TableFileSchema::TO_DELETE); + ms::server::CollectDurationMetrics duration_metrics_raw(ms::engine::meta::TableFileSchema::RAW); + ms::server::CollectDurationMetrics duration_metrics_index(ms::engine::meta::TableFileSchema::TO_INDEX); + ms::server::CollectDurationMetrics duration_metrics_delete(ms::engine::meta::TableFileSchema::TO_DELETE); - server::CollectSearchTaskMetrics search_metrics_raw(engine::meta::TableFileSchema::RAW); - server::CollectSearchTaskMetrics search_metrics_index(engine::meta::TableFileSchema::TO_INDEX); - server::CollectSearchTaskMetrics search_metrics_delete(engine::meta::TableFileSchema::TO_DELETE); + ms::server::CollectSearchTaskMetrics search_metrics_raw(ms::engine::meta::TableFileSchema::RAW); + ms::server::CollectSearchTaskMetrics search_metrics_index(ms::engine::meta::TableFileSchema::TO_INDEX); + ms::server::CollectSearchTaskMetrics search_metrics_delete(ms::engine::meta::TableFileSchema::TO_DELETE); - server::MetricCollector metric_collector(); + ms::server::MetricCollector metric_collector(); } diff --git a/cpp/unittest/metrics/prometheus_test.cpp b/cpp/unittest/metrics/test_prometheus.cpp similarity index 87% rename from cpp/unittest/metrics/prometheus_test.cpp rename to cpp/unittest/metrics/test_prometheus.cpp index a634a6ff9c848efe9a1bca8c513ef8cdae8909b1..14982d058d154bf248bc352068af65eb43107d41 100644 --- a/cpp/unittest/metrics/prometheus_test.cpp +++ b/cpp/unittest/metrics/test_prometheus.cpp @@ -22,15 +22,19 @@ #include #include -using namespace zilliz::milvus; +namespace { -TEST(PrometheusTest, PROMETHEUS_TEST){ - server::Config::GetInstance().SetMetricConfigEnableMonitor("on"); +namespace ms = milvus; - server::PrometheusMetrics instance = server::PrometheusMetrics::GetInstance(); +} // namespace + +TEST(PrometheusTest, PROMETHEUS_TEST) { + ms::server::Config::GetInstance().SetMetricConfigEnableMonitor("on"); + + ms::server::PrometheusMetrics instance = ms::server::PrometheusMetrics::GetInstance(); instance.Init(); instance.SetStartup(true); - server::SystemInfo::GetInstance().Init(); + ms::server::SystemInfo::GetInstance().Init(); instance.AddVectorsSuccessTotalIncrement(); instance.AddVectorsFailTotalIncrement(); instance.AddVectorsDurationHistogramOberve(1.0); @@ -64,7 +68,7 @@ TEST(PrometheusTest, PROMETHEUS_TEST){ instance.QueryResponsePerSecondGaugeSet(1.0); instance.GPUPercentGaugeSet(); instance.GPUMemoryUsageGaugeSet(); - instance.AddVectorsPerSecondGaugeSet(1,1,1); + instance.AddVectorsPerSecondGaugeSet(1, 1, 1); instance.QueryIndexTypePerSecondSet("IVF", 1.0); instance.QueryIndexTypePerSecondSet("IDMap", 1.0); instance.ConnectionGaugeIncrement(); @@ -76,10 +80,9 @@ TEST(PrometheusTest, PROMETHEUS_TEST){ instance.GPUTemperature(); instance.CPUTemperature(); - server::Config::GetInstance().SetMetricConfigEnableMonitor("off"); + ms::server::Config::GetInstance().SetMetricConfigEnableMonitor("off"); instance.Init(); instance.CPUCoreUsagePercentSet(); instance.GPUTemperature(); instance.CPUTemperature(); - -} \ No newline at end of file +} diff --git a/cpp/unittest/metrics/utils.cpp b/cpp/unittest/metrics/utils.cpp index 8d59a1e9c2ac8171bd728fbb8fa6816be8fe7a2a..e345923b7b59ef094e23e101fd102c920f7ea894 100644 --- a/cpp/unittest/metrics/utils.cpp +++ b/cpp/unittest/metrics/utils.cpp @@ -18,32 +18,38 @@ #include #include +#include #include -#include "utils.h" +#include "metrics/utils.h" #include "db/DBFactory.h" INITIALIZE_EASYLOGGINGPP -using namespace zilliz::milvus; +namespace { -static std::string uri; +namespace ms = milvus; class DBTestEnvironment : public ::testing::Environment { -public: + public: + explicit DBTestEnvironment(const std::string& uri) : uri_(uri) {} -// explicit DBTestEnvironment(std::string uri) : uri_(uri) {} - - static std::string getURI() { - return uri; + std::string getURI() const { + return uri_; } void SetUp() override { getURI(); } + private: + std::string uri_; }; +DBTestEnvironment* test_env = nullptr; + +} // namespace + void MetricTest::InitLog() { el::Configurations defaultConf; defaultConf.setToDefault(); @@ -52,17 +58,18 @@ void MetricTest::InitLog() { el::Loggers::reconfigureLogger("default", defaultConf); } -engine::DBOptions MetricTest::GetOptions() { - auto options = engine::DBFactory::BuildOption(); +ms::engine::DBOptions MetricTest::GetOptions() { + auto options = ms::engine::DBFactory::BuildOption(); options.meta_.path_ = "/tmp/milvus_test"; options.meta_.backend_uri_ = "sqlite://:@:/"; return options; } void MetricTest::SetUp() { + boost::filesystem::remove_all("/tmp/milvus_test"); InitLog(); auto options = GetOptions(); - db_ = engine::DBFactory::Build(options); + db_ = ms::engine::DBFactory::Build(options); } void MetricTest::TearDown() { @@ -72,10 +79,12 @@ void MetricTest::TearDown() { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); + + std::string uri; if (argc > 1) { uri = argv[1]; } -// std::cout << uri << std::endl; - ::testing::AddGlobalTestEnvironment(new DBTestEnvironment); + test_env = new DBTestEnvironment(uri); + ::testing::AddGlobalTestEnvironment(test_env); return RUN_ALL_TESTS(); } diff --git a/cpp/unittest/metrics/utils.h b/cpp/unittest/metrics/utils.h index 37c21c82f93a19a59ff5c5ceee4c0ed588597e2b..996dec89f59e97a4119b25b5ea4f7a296f0006a4 100644 --- a/cpp/unittest/metrics/utils.h +++ b/cpp/unittest/metrics/utils.h @@ -26,7 +26,6 @@ #include "db/meta/SqliteMetaImpl.h" #include "db/meta/MySQLMetaImpl.h" - #define TIMING #ifdef TIMING @@ -34,15 +33,15 @@ #define START_TIMER start = std::chrono::high_resolution_clock::now(); #define STOP_TIMER(name) LOG(DEBUG) << "RUNTIME of " << name << ": " << \ std::chrono::duration_cast( \ - std::chrono::high_resolution_clock::now()-start \ - ).count() << " ms "; + std::chrono::high_resolution_clock::now()-start).count() << " ms "; #else #define INIT_TIMER #define START_TIMER #define STOP_TIMER(name) #endif -void ASSERT_STATS(zilliz::milvus::Status& stat); +void +ASSERT_STATS(milvus::Status &stat); //class TestEnv : public ::testing::Environment { //public: @@ -66,11 +65,11 @@ void ASSERT_STATS(zilliz::milvus::Status& stat); // ::testing::AddGlobalTestEnvironment(new TestEnv); class MetricTest : public ::testing::Test { -protected: - zilliz::milvus::engine::DBPtr db_; + protected: + milvus::engine::DBPtr db_; void InitLog(); - virtual void SetUp() override; - virtual void TearDown() override; - virtual zilliz::milvus::engine::DBOptions GetOptions(); -}; \ No newline at end of file + void SetUp() override; + void TearDown() override; + virtual milvus::engine::DBOptions GetOptions(); +}; diff --git a/cpp/unittest/scheduler/CMakeLists.txt b/cpp/unittest/scheduler/CMakeLists.txt index 27ee11775c2420c8ab25ebfaef84c920c43b8603..087f93f017c38f7b9858f46f17abe0f03bc0b7ad 100644 --- a/cpp/unittest/scheduler/CMakeLists.txt +++ b/cpp/unittest/scheduler/CMakeLists.txt @@ -19,18 +19,15 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} test_files) -include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") -link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") - -set(scheduler_test_files +cuda_add_executable(test_scheduler ${common_files} - ${unittest_files} + ${entry_file} ${test_files} ) -cuda_add_executable(scheduler_test ${scheduler_test_files}) - -target_link_libraries(scheduler_test knowhere ${unittest_libs}) +target_link_libraries(test_scheduler + knowhere + ${unittest_libs}) -install(TARGETS scheduler_test DESTINATION unittest) +install(TARGETS test_scheduler DESTINATION unittest) diff --git a/cpp/unittest/scheduler/task_test.cpp b/cpp/unittest/scheduler/task_test.cpp index 17d1e28cd69da7674686a4ce2090be5f751878f4..07e85c723c65808978283f9b244cd9d95d44c95a 100644 --- a/cpp/unittest/scheduler/task_test.cpp +++ b/cpp/unittest/scheduler/task_test.cpp @@ -20,16 +20,15 @@ #include -namespace zilliz { namespace milvus { namespace scheduler { - TEST(TaskTest, INVALID_INDEX) { - auto search_task = std::make_shared(nullptr); + auto search_task = std::make_shared(nullptr, nullptr); search_task->Load(LoadType::TEST, 10); } -} -} -} +} // namespace scheduler +} // namespace milvus + + diff --git a/cpp/unittest/scheduler/algorithm_test.cpp b/cpp/unittest/scheduler/test_algorithm.cpp similarity index 98% rename from cpp/unittest/scheduler/algorithm_test.cpp rename to cpp/unittest/scheduler/test_algorithm.cpp index fe619c4981f5f71d648ee812c1a3da20ee6da592..bcb90873735e3b361e44bf83ec70ac9322a91e6b 100644 --- a/cpp/unittest/scheduler/algorithm_test.cpp +++ b/cpp/unittest/scheduler/test_algorithm.cpp @@ -24,7 +24,7 @@ #include "scheduler/ResourceFactory.h" #include "scheduler/Algorithm.h" -namespace zilliz { + namespace milvus { namespace scheduler { @@ -101,11 +101,8 @@ TEST_F(AlgorithmTest, SHORTESTPATH_TEST) { std::cout << sp[sp.size() - 1] << std::endl; sp.pop_back(); } - - } +} // namespace scheduler +} // namespace milvus -} -} -} \ No newline at end of file diff --git a/cpp/unittest/scheduler/event_test.cpp b/cpp/unittest/scheduler/test_event.cpp similarity index 97% rename from cpp/unittest/scheduler/event_test.cpp rename to cpp/unittest/scheduler/test_event.cpp index 34d4c2ce23878f8e8fd0de5e4bb73be3b0fa685c..07d51e8557730740614ed5485d4daea6c0489e71 100644 --- a/cpp/unittest/scheduler/event_test.cpp +++ b/cpp/unittest/scheduler/test_event.cpp @@ -23,7 +23,7 @@ #include "scheduler/event/StartUpEvent.h" -namespace zilliz { + namespace milvus { namespace scheduler { @@ -60,7 +60,7 @@ TEST(EventTest, TASKTABLE_UPDATED_EVENT) { std::cout << *EventPtr(event); } -} -} -} +} // namespace scheduler +} // namespace milvus + diff --git a/cpp/unittest/scheduler/node_test.cpp b/cpp/unittest/scheduler/test_node.cpp similarity index 81% rename from cpp/unittest/scheduler/node_test.cpp rename to cpp/unittest/scheduler/test_node.cpp index b87611d3ca2ec347a59ef90e3abc8a27da02f9f5..9b34b73191d25f9f59ec88537c24a1668a87d8af 100644 --- a/cpp/unittest/scheduler/node_test.cpp +++ b/cpp/unittest/scheduler/test_node.cpp @@ -19,37 +19,40 @@ #include "scheduler/resource/Node.h" #include +namespace { -using namespace zilliz::milvus::scheduler; +namespace ms = milvus::scheduler; + +} // namespace class NodeTest : public ::testing::Test { -protected: + protected: void SetUp() override { - node1_ = std::make_shared(); - node2_ = std::make_shared(); - node3_ = std::make_shared(); - isolated_node1_ = std::make_shared(); - isolated_node2_ = std::make_shared(); + node1_ = std::make_shared(); + node2_ = std::make_shared(); + node3_ = std::make_shared(); + isolated_node1_ = std::make_shared(); + isolated_node2_ = std::make_shared(); - auto pcie = Connection("PCIe", 11.0); + auto pcie = ms::Connection("PCIe", 11.0); node1_->AddNeighbour(node2_, pcie); node1_->AddNeighbour(node3_, pcie); node2_->AddNeighbour(node1_, pcie); } - NodePtr node1_; - NodePtr node2_; - NodePtr node3_; - NodePtr isolated_node1_; - NodePtr isolated_node2_; + ms::NodePtr node1_; + ms::NodePtr node2_; + ms::NodePtr node3_; + ms::NodePtr isolated_node1_; + ms::NodePtr isolated_node2_; }; TEST_F(NodeTest, ADD_NEIGHBOUR) { ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); - auto pcie = Connection("PCIe", 11.0); + auto pcie = ms::Connection("PCIe", 11.0); isolated_node1_->AddNeighbour(isolated_node2_, pcie); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); @@ -58,7 +61,7 @@ TEST_F(NodeTest, ADD_NEIGHBOUR) { TEST_F(NodeTest, REPEAT_ADD_NEIGHBOUR) { ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 0); ASSERT_EQ(isolated_node2_->GetNeighbours().size(), 0); - auto pcie = Connection("PCIe", 11.0); + auto pcie = ms::Connection("PCIe", 11.0); isolated_node1_->AddNeighbour(isolated_node2_, pcie); isolated_node1_->AddNeighbour(isolated_node2_, pcie); ASSERT_EQ(isolated_node1_->GetNeighbours().size(), 1); @@ -97,3 +100,4 @@ TEST_F(NodeTest, DUMP) { std::cout << node2_->Dump(); ASSERT_FALSE(node2_->Dump().empty()); } + diff --git a/cpp/unittest/scheduler/normal_test.cpp b/cpp/unittest/scheduler/test_normal.cpp similarity index 72% rename from cpp/unittest/scheduler/normal_test.cpp rename to cpp/unittest/scheduler/test_normal.cpp index 535fa967315be297baf8982925c4f5e3a76676f2..1dbd93e0449a5a95bf1390e8d670d719947b73bd 100644 --- a/cpp/unittest/scheduler/normal_test.cpp +++ b/cpp/unittest/scheduler/test_normal.cpp @@ -25,35 +25,38 @@ #include "utils/Log.h" #include +namespace { -using namespace zilliz::milvus::scheduler; +namespace ms = milvus::scheduler; +} // namespace TEST(NormalTest, INST_TEST) { // ResourceMgr only compose resources, provide unified event - auto res_mgr = ResMgrInst::GetInstance(); + auto res_mgr = ms::ResMgrInst::GetInstance(); - res_mgr->Add(ResourceFactory::Create("disk", "DISK", 0, true, false)); - res_mgr->Add(ResourceFactory::Create("cpu", "CPU", 0, true, true)); + res_mgr->Add(ms::ResourceFactory::Create("disk", "DISK", 0, true, false)); + res_mgr->Add(ms::ResourceFactory::Create("cpu", "CPU", 0, true, true)); - auto IO = Connection("IO", 500.0); + auto IO = ms::Connection("IO", 500.0); res_mgr->Connect("disk", "cpu", IO); - auto scheduler = SchedInst::GetInstance(); + auto scheduler = ms::SchedInst::GetInstance(); res_mgr->Start(); scheduler->Start(); const uint64_t NUM_TASK = 1000; - std::vector> tasks; - TableFileSchemaPtr dummy = nullptr; + std::vector> tasks; + ms::TableFileSchemaPtr dummy = nullptr; auto disks = res_mgr->GetDiskResources(); ASSERT_FALSE(disks.empty()); if (auto observe = disks[0].lock()) { for (uint64_t i = 0; i < NUM_TASK; ++i) { - auto task = std::make_shared(dummy); - task->label() = std::make_shared(); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); + task->label() = std::make_shared(); tasks.push_back(task); observe->task_table().Put(task); } @@ -67,5 +70,4 @@ TEST(NormalTest, INST_TEST) { scheduler->Stop(); res_mgr->Stop(); - } diff --git a/cpp/unittest/scheduler/resource_test.cpp b/cpp/unittest/scheduler/test_resource.cpp similarity index 88% rename from cpp/unittest/scheduler/resource_test.cpp rename to cpp/unittest/scheduler/test_resource.cpp index b335a601db66421bb3eedb51f6ec7840156f1bca..31fe425959a1312afe987e9a81f4462ad4f1e2b4 100644 --- a/cpp/unittest/scheduler/resource_test.cpp +++ b/cpp/unittest/scheduler/test_resource.cpp @@ -23,23 +23,23 @@ #include "scheduler/resource/TestResource.h" #include "scheduler/task/Task.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include "scheduler/ResourceFactory.h" #include -namespace zilliz { namespace milvus { namespace scheduler { /************ ResourceBaseTest ************/ class ResourceBaseTest : public testing::Test { -protected: + protected: void SetUp() override { - only_loader_ = std::make_shared(name1, id1, true, false); - only_executor_ = std::make_shared(name2, id2, false, true); - both_enable_ = std::make_shared(name3, id3, true, true); - both_disable_ = std::make_shared(name4, id4, false, false); + only_loader_ = std::make_shared(name1, id1, true, false); + only_executor_ = std::make_shared(name2, id2, false, true); + both_enable_ = std::make_shared(name3, id3, true, true); + both_disable_ = std::make_shared(name4, id4, false, false); } const std::string name1 = "only_loader_"; @@ -104,7 +104,7 @@ TEST_F(ResourceBaseTest, DUMP) { /************ ResourceAdvanceTest ************/ class ResourceAdvanceTest : public testing::Test { -protected: + protected: void SetUp() override { disk_resource_ = ResourceFactory::Create("ssd", "DISK", 0); @@ -156,13 +156,17 @@ protected: void WaitLoader(uint64_t count) { std::unique_lock lock(load_mutex_); - cv_.wait(lock, [&] { return load_count_ == count; }); + cv_.wait(lock, [&] { + return load_count_ == count; + }); } void WaitExecutor(uint64_t count) { std::unique_lock lock(exec_mutex_); - cv_.wait(lock, [&] { return exec_count_ == count; }); + cv_.wait(lock, [&] { + return exec_count_ == count; + }); } ResourcePtr disk_resource_; @@ -182,7 +186,8 @@ TEST_F(ResourceAdvanceTest, DISK_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); disk_resource_->task_table().Put(task); } @@ -207,7 +212,8 @@ TEST_F(ResourceAdvanceTest, CPU_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); cpu_resource_->task_table().Put(task); } @@ -232,7 +238,8 @@ TEST_F(ResourceAdvanceTest, GPU_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); gpu_resource_->task_table().Put(task); } @@ -257,7 +264,8 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); test_resource_->task_table().Put(task); } @@ -277,6 +285,6 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) { } } -} -} -} +} // namespace scheduler +} // namespace milvus + diff --git a/cpp/unittest/scheduler/resource_factory_test.cpp b/cpp/unittest/scheduler/test_resource_factory.cpp similarity index 67% rename from cpp/unittest/scheduler/resource_factory_test.cpp rename to cpp/unittest/scheduler/test_resource_factory.cpp index e2a84257ebec57ff925ef2769ee7a3bffbeea32d..aaad3aa2c99d63b208778116c4e3d940af77c5e8 100644 --- a/cpp/unittest/scheduler/resource_factory_test.cpp +++ b/cpp/unittest/scheduler/test_resource_factory.cpp @@ -19,15 +19,18 @@ #include "scheduler/ResourceFactory.h" #include +namespace { -using namespace zilliz::milvus::scheduler; +namespace ms = milvus::scheduler; + +} // namespace TEST(ResourceFactoryTest, CREATE) { - auto disk = ResourceFactory::Create("ssd", "DISK", 0); - auto cpu = ResourceFactory::Create("cpu", "CPU", 0); - auto gpu = ResourceFactory::Create("gpu", "GPU", 0); + auto disk = ms::ResourceFactory::Create("ssd", "DISK", 0); + auto cpu = ms::ResourceFactory::Create("cpu", "CPU", 0); + auto gpu = ms::ResourceFactory::Create("gpu", "GPU", 0); - ASSERT_TRUE(std::dynamic_pointer_cast(disk)); - ASSERT_TRUE(std::dynamic_pointer_cast(cpu)); - ASSERT_TRUE(std::dynamic_pointer_cast(gpu)); + ASSERT_TRUE(std::dynamic_pointer_cast(disk)); + ASSERT_TRUE(std::dynamic_pointer_cast(cpu)); + ASSERT_TRUE(std::dynamic_pointer_cast(gpu)); } diff --git a/cpp/unittest/scheduler/resource_mgr_test.cpp b/cpp/unittest/scheduler/test_resource_mgr.cpp similarity index 96% rename from cpp/unittest/scheduler/resource_mgr_test.cpp rename to cpp/unittest/scheduler/test_resource_mgr.cpp index c2be785a005497ced14288684de146067fa01479..34e6b50c49fe4df11975a21e860177194c1155d9 100644 --- a/cpp/unittest/scheduler/resource_mgr_test.cpp +++ b/cpp/unittest/scheduler/test_resource_mgr.cpp @@ -21,18 +21,17 @@ #include "scheduler/resource/DiskResource.h" #include "scheduler/resource/TestResource.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include "scheduler/ResourceMgr.h" #include -namespace zilliz { namespace milvus { namespace scheduler { - /************ ResourceMgrBaseTest ************/ class ResourceMgrBaseTest : public testing::Test { -protected: + protected: void SetUp() override { empty_mgr_ = std::make_shared(); @@ -77,7 +76,6 @@ TEST_F(ResourceMgrBaseTest, CONNECT) { ASSERT_TRUE(empty_mgr_->Connect("resource1", "resource2", io)); } - TEST_F(ResourceMgrBaseTest, INVALID_CONNECT) { auto resource1 = std::make_shared("resource1", 0, true, true); auto resource2 = std::make_shared("resource2", 2, true, true); @@ -87,7 +85,6 @@ TEST_F(ResourceMgrBaseTest, INVALID_CONNECT) { ASSERT_FALSE(empty_mgr_->Connect("xx", "yy", io)); } - TEST_F(ResourceMgrBaseTest, CLEAR) { ASSERT_EQ(mgr1_->GetNumOfResource(), 3); mgr1_->Clear(); @@ -163,7 +160,7 @@ TEST_F(ResourceMgrBaseTest, DUMP_TASKTABLES) { /************ ResourceMgrAdvanceTest ************/ class ResourceMgrAdvanceTest : public testing::Test { - protected: + protected: void SetUp() override { mgr1_ = std::make_shared(); @@ -188,12 +185,12 @@ TEST_F(ResourceMgrAdvanceTest, REGISTER_SUBSCRIBER) { }; mgr1_->RegisterSubscriber(callback); TableFileSchemaPtr dummy = nullptr; - disk_res->task_table().Put(std::make_shared(dummy)); + auto label = std::make_shared(); + disk_res->task_table().Put(std::make_shared(dummy, label)); sleep(1); ASSERT_TRUE(flag); } +} // namespace scheduler +} // namespace milvus -} -} -} diff --git a/cpp/unittest/scheduler/schedinst_test.cpp b/cpp/unittest/scheduler/test_schedinst.cpp similarity index 97% rename from cpp/unittest/scheduler/schedinst_test.cpp rename to cpp/unittest/scheduler/test_schedinst.cpp index cf1003f5ce591a46acf593e46bc1b1893dcab3c3..e63a9615bcb37dfaef14987fe515cfc89203d581 100644 --- a/cpp/unittest/scheduler/schedinst_test.cpp +++ b/cpp/unittest/scheduler/test_schedinst.cpp @@ -21,14 +21,12 @@ #include "server/Config.h" #include "scheduler/SchedInst.h" -namespace zilliz { + namespace milvus { namespace scheduler { - class SchedInstTest : public testing::Test { - -protected: + protected: void SetUp() override { boost::filesystem::create_directory(TMP_DIR); @@ -83,6 +81,8 @@ TEST_F(SchedInstTest, SIMPLE_GPU) { StartSchedulerService(); } -} -} -} +} // namespace scheduler +} // namespace milvus + + + diff --git a/cpp/unittest/scheduler/scheduler_test.cpp b/cpp/unittest/scheduler/test_scheduler.cpp similarity index 84% rename from cpp/unittest/scheduler/scheduler_test.cpp rename to cpp/unittest/scheduler/test_scheduler.cpp index 8fc65f7d3b59b7d7ff3562f1893370c059ab6a24..1238f906d1e2e3246d47ec7eed750032bef494bb 100644 --- a/cpp/unittest/scheduler/scheduler_test.cpp +++ b/cpp/unittest/scheduler/test_scheduler.cpp @@ -29,22 +29,21 @@ #include "wrapper/VecIndex.h" -namespace zilliz { namespace milvus { namespace scheduler { class MockVecIndex : public engine::VecIndex { -public: - virtual Status BuildAll(const long &nb, - const float *xb, - const long *ids, - const engine::Config &cfg, - const long &nt = 0, - const float *xt = nullptr) { + public: + virtual Status BuildAll(const int64_t &nb, + const float *xb, + const int64_t *ids, + const engine::Config &cfg, + const int64_t &nt = 0, + const float *xt = nullptr) { } engine::VecIndexPtr Clone() override { - return zilliz::milvus::engine::VecIndexPtr(); + return milvus::engine::VecIndexPtr(); } int64_t GetDeviceId() override { @@ -55,27 +54,23 @@ public: return engine::IndexType::INVALID; } - virtual Status Add(const long &nb, - const float *xb, - const long *ids, - const engine::Config &cfg = engine::Config()) { - + virtual Status Add(const int64_t &nb, + const float *xb, + const int64_t *ids, + const engine::Config &cfg = engine::Config()) { } - virtual Status Search(const long &nq, - const float *xq, - float *dist, - long *ids, - const engine::Config &cfg = engine::Config()) { - + virtual Status Search(const int64_t &nq, + const float *xq, + float *dist, + int64_t *ids, + const engine::Config &cfg = engine::Config()) { } engine::VecIndexPtr CopyToGpu(const int64_t &device_id, const engine::Config &cfg) override { - } engine::VecIndexPtr CopyToCpu(const engine::Config &cfg) override { - } virtual int64_t Dimension() { @@ -86,26 +81,24 @@ public: return ntotal_; } - virtual zilliz::knowhere::BinarySet Serialize() { - zilliz::knowhere::BinarySet binset; + virtual knowhere::BinarySet Serialize() { + knowhere::BinarySet binset; return binset; } - virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) { - + virtual Status Load(const knowhere::BinarySet &index_binary) { } -public: + public: int64_t dimension_ = 512; int64_t ntotal_ = 0; }; - class SchedulerTest : public testing::Test { -protected: + protected: void SetUp() override { - constexpr int64_t cache_cap = 1024*1024*1024; + constexpr int64_t cache_cap = 1024 * 1024 * 1024; cache::GpuCacheMgr::GetInstance(0)->SetCapacity(cache_cap); cache::GpuCacheMgr::GetInstance(1)->SetCapacity(cache_cap); @@ -162,7 +155,8 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) { insert_dummy_index_into_gpu_cache(1); for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); task->label() = std::make_shared(); tasks.push_back(task); cpu_resource_.lock()->task_table().Put(task); @@ -170,7 +164,6 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) { sleep(3); ASSERT_EQ(res_mgr_->GetResource(ResourceType::GPU, 1)->task_table().Size(), NUM); - } TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) { @@ -182,7 +175,8 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) { tasks.clear(); for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy1); + auto label = std::make_shared(); + auto task = std::make_shared(dummy1, label); task->label() = std::make_shared(); tasks.push_back(task); cpu_resource_.lock()->task_table().Put(task); @@ -193,7 +187,7 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) { } class SchedulerTest2 : public testing::Test { -protected: + protected: void SetUp() override { ResourcePtr disk = ResourceFactory::Create("disk", "DISK", 0, true, false); @@ -243,7 +237,6 @@ protected: std::shared_ptr scheduler_; }; - TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) { const uint64_t NUM = 10; std::vector> tasks; @@ -251,7 +244,8 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) { dummy->location_ = "location"; for (uint64_t i = 0; i < NUM; ++i) { - std::shared_ptr task = std::make_shared(dummy); + auto label = std::make_shared(); + std::shared_ptr task = std::make_shared(dummy, label); task->label() = std::make_shared(disk_); tasks.push_back(task); disk_.lock()->task_table().Put(task); @@ -260,6 +254,6 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) { // ASSERT_EQ(res_mgr_->GetResource(ResourceType::GPU, 1)->task_table().Size(), NUM); } -} -} -} +} // namespace scheduler +} // namespace milvus + diff --git a/cpp/unittest/scheduler/tasktable_test.cpp b/cpp/unittest/scheduler/test_tasktable.cpp similarity index 64% rename from cpp/unittest/scheduler/tasktable_test.cpp rename to cpp/unittest/scheduler/test_tasktable.cpp index 5c64a72351fb1409de0cabf452464a1d20f83da2..271826614d4cede85e1dede3cafb00c160897ead 100644 --- a/cpp/unittest/scheduler/tasktable_test.cpp +++ b/cpp/unittest/scheduler/test_tasktable.cpp @@ -18,53 +18,56 @@ #include "scheduler/TaskTable.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include +namespace { -using namespace zilliz::milvus::scheduler; +namespace ms = milvus::scheduler; +} // namespace /************ TaskTableBaseTest ************/ class TaskTableItemTest : public ::testing::Test { -protected: + protected: void SetUp() override { - std::vector states{ - TaskTableItemState::INVALID, - TaskTableItemState::START, - TaskTableItemState::LOADING, - TaskTableItemState::LOADED, - TaskTableItemState::EXECUTING, - TaskTableItemState::EXECUTED, - TaskTableItemState::MOVING, - TaskTableItemState::MOVED}; + std::vector states{ + ms::TaskTableItemState::INVALID, + ms::TaskTableItemState::START, + ms::TaskTableItemState::LOADING, + ms::TaskTableItemState::LOADED, + ms::TaskTableItemState::EXECUTING, + ms::TaskTableItemState::EXECUTED, + ms::TaskTableItemState::MOVING, + ms::TaskTableItemState::MOVED}; for (auto &state : states) { - auto item = std::make_shared(); + auto item = std::make_shared(); item->state = state; items_.emplace_back(item); } } - TaskTableItem default_; - std::vector items_; + ms::TaskTableItem default_; + std::vector items_; }; TEST_F(TaskTableItemTest, CONSTRUCT) { ASSERT_EQ(default_.id, 0); ASSERT_EQ(default_.task, nullptr); - ASSERT_EQ(default_.state, TaskTableItemState::INVALID); + ASSERT_EQ(default_.state, ms::TaskTableItemState::INVALID); } TEST_F(TaskTableItemTest, DESTRUCT) { - auto p_item = new TaskTableItem(); + auto p_item = new ms::TaskTableItem(); delete p_item; } TEST_F(TaskTableItemTest, IS_FINISH) { for (auto &item : items_) { - if (item->state == TaskTableItemState::EXECUTED - || item->state == TaskTableItemState::MOVED) { + if (item->state == ms::TaskTableItemState::EXECUTED + || item->state == ms::TaskTableItemState::MOVED) { ASSERT_TRUE(item->IsFinish()); } else { ASSERT_FALSE(item->IsFinish()); @@ -82,9 +85,9 @@ TEST_F(TaskTableItemTest, LOAD) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Load(); - if (before_state == TaskTableItemState::START) { + if (before_state == ms::TaskTableItemState::START) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::LOADING); + ASSERT_EQ(item->state, ms::TaskTableItemState::LOADING); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -96,9 +99,9 @@ TEST_F(TaskTableItemTest, LOADED) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Loaded(); - if (before_state == TaskTableItemState::LOADING) { + if (before_state == ms::TaskTableItemState::LOADING) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::LOADED); + ASSERT_EQ(item->state, ms::TaskTableItemState::LOADED); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -110,9 +113,9 @@ TEST_F(TaskTableItemTest, EXECUTE) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Execute(); - if (before_state == TaskTableItemState::LOADED) { + if (before_state == ms::TaskTableItemState::LOADED) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::EXECUTING); + ASSERT_EQ(item->state, ms::TaskTableItemState::EXECUTING); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -120,14 +123,13 @@ TEST_F(TaskTableItemTest, EXECUTE) { } } - TEST_F(TaskTableItemTest, EXECUTED) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Executed(); - if (before_state == TaskTableItemState::EXECUTING) { + if (before_state == ms::TaskTableItemState::EXECUTING) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::EXECUTED); + ASSERT_EQ(item->state, ms::TaskTableItemState::EXECUTED); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -139,9 +141,9 @@ TEST_F(TaskTableItemTest, MOVE) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Move(); - if (before_state == TaskTableItemState::LOADED) { + if (before_state == ms::TaskTableItemState::LOADED) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::MOVING); + ASSERT_EQ(item->state, ms::TaskTableItemState::MOVING); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -153,9 +155,9 @@ TEST_F(TaskTableItemTest, MOVED) { for (auto &item : items_) { auto before_state = item->state; auto ret = item->Moved(); - if (before_state == TaskTableItemState::MOVING) { + if (before_state == ms::TaskTableItemState::MOVING) { ASSERT_TRUE(ret); - ASSERT_EQ(item->state, TaskTableItemState::MOVED); + ASSERT_EQ(item->state, ms::TaskTableItemState::MOVED); } else { ASSERT_FALSE(ret); ASSERT_EQ(item->state, before_state); @@ -166,19 +168,20 @@ TEST_F(TaskTableItemTest, MOVED) { /************ TaskTableBaseTest ************/ class TaskTableBaseTest : public ::testing::Test { -protected: + protected: void SetUp() override { - TableFileSchemaPtr dummy = nullptr; + ms::TableFileSchemaPtr dummy = nullptr; invalid_task_ = nullptr; - task1_ = std::make_shared(dummy); - task2_ = std::make_shared(dummy); + auto label = std::make_shared(); + task1_ = std::make_shared(dummy, label); + task2_ = std::make_shared(dummy, label); } - TaskPtr invalid_task_; - TaskPtr task1_; - TaskPtr task2_; - TaskTable empty_table_; + ms::TaskPtr invalid_task_; + ms::TaskPtr task1_; + ms::TaskPtr task2_; + ms::TaskTable empty_table_; }; TEST_F(TaskTableBaseTest, SUBSCRIBER) { @@ -191,7 +194,6 @@ TEST_F(TaskTableBaseTest, SUBSCRIBER) { ASSERT_TRUE(flag); } - TEST_F(TaskTableBaseTest, PUT_TASK) { empty_table_.Put(task1_); ASSERT_EQ(empty_table_.Get(0)->task, task1_); @@ -203,14 +205,14 @@ TEST_F(TaskTableBaseTest, PUT_INVALID_TEST) { } TEST_F(TaskTableBaseTest, PUT_BATCH) { - std::vector tasks{task1_, task2_}; + std::vector tasks{task1_, task2_}; empty_table_.Put(tasks); ASSERT_EQ(empty_table_.Get(0)->task, task1_); ASSERT_EQ(empty_table_.Get(1)->task, task2_); } TEST_F(TaskTableBaseTest, PUT_EMPTY_BATCH) { - std::vector tasks{}; + std::vector tasks{}; empty_table_.Put(tasks); } @@ -236,8 +238,8 @@ TEST_F(TaskTableBaseTest, PICK_TO_LOAD) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; auto indexes = empty_table_.PickToLoad(1); ASSERT_EQ(indexes.size(), 1); @@ -249,8 +251,8 @@ TEST_F(TaskTableBaseTest, PICK_TO_LOAD_LIMIT) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; auto indexes = empty_table_.PickToLoad(3); ASSERT_EQ(indexes.size(), 3); @@ -264,8 +266,8 @@ TEST_F(TaskTableBaseTest, PICK_TO_LOAD_CACHE) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; // first pick, non-cache auto indexes = empty_table_.PickToLoad(1); @@ -274,7 +276,7 @@ TEST_F(TaskTableBaseTest, PICK_TO_LOAD_CACHE) { // second pick, iterate from 2 // invalid state change - empty_table_[1]->state = TaskTableItemState::START; + empty_table_[1]->state = ms::TaskTableItemState::START; indexes = empty_table_.PickToLoad(1); ASSERT_EQ(indexes.size(), 1); ASSERT_EQ(indexes[0], 2); @@ -285,9 +287,9 @@ TEST_F(TaskTableBaseTest, PICK_TO_EXECUTE) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; - empty_table_[2]->state = TaskTableItemState::LOADED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; + empty_table_[2]->state = ms::TaskTableItemState::LOADED; auto indexes = empty_table_.PickToExecute(1); ASSERT_EQ(indexes.size(), 1); @@ -299,10 +301,10 @@ TEST_F(TaskTableBaseTest, PICK_TO_EXECUTE_LIMIT) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; - empty_table_[2]->state = TaskTableItemState::LOADED; - empty_table_[3]->state = TaskTableItemState::LOADED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; + empty_table_[2]->state = ms::TaskTableItemState::LOADED; + empty_table_[3]->state = ms::TaskTableItemState::LOADED; auto indexes = empty_table_.PickToExecute(3); ASSERT_EQ(indexes.size(), 2); @@ -315,9 +317,9 @@ TEST_F(TaskTableBaseTest, PICK_TO_EXECUTE_CACHE) { for (size_t i = 0; i < NUM_TASKS; ++i) { empty_table_.Put(task1_); } - empty_table_[0]->state = TaskTableItemState::MOVED; - empty_table_[1]->state = TaskTableItemState::EXECUTED; - empty_table_[2]->state = TaskTableItemState::LOADED; + empty_table_[0]->state = ms::TaskTableItemState::MOVED; + empty_table_[1]->state = ms::TaskTableItemState::EXECUTED; + empty_table_[2]->state = ms::TaskTableItemState::LOADED; // first pick, non-cache auto indexes = empty_table_.PickToExecute(1); @@ -326,40 +328,40 @@ TEST_F(TaskTableBaseTest, PICK_TO_EXECUTE_CACHE) { // second pick, iterate from 2 // invalid state change - empty_table_[1]->state = TaskTableItemState::START; + empty_table_[1]->state = ms::TaskTableItemState::START; indexes = empty_table_.PickToExecute(1); ASSERT_EQ(indexes.size(), 1); ASSERT_EQ(indexes[0], 2); } - /************ TaskTableAdvanceTest ************/ class TaskTableAdvanceTest : public ::testing::Test { -protected: + protected: void SetUp() override { - TableFileSchemaPtr dummy = nullptr; + ms::TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < 8; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); table1_.Put(task); } - table1_.Get(0)->state = TaskTableItemState::INVALID; - table1_.Get(1)->state = TaskTableItemState::START; - table1_.Get(2)->state = TaskTableItemState::LOADING; - table1_.Get(3)->state = TaskTableItemState::LOADED; - table1_.Get(4)->state = TaskTableItemState::EXECUTING; - table1_.Get(5)->state = TaskTableItemState::EXECUTED; - table1_.Get(6)->state = TaskTableItemState::MOVING; - table1_.Get(7)->state = TaskTableItemState::MOVED; + table1_.Get(0)->state = ms::TaskTableItemState::INVALID; + table1_.Get(1)->state = ms::TaskTableItemState::START; + table1_.Get(2)->state = ms::TaskTableItemState::LOADING; + table1_.Get(3)->state = ms::TaskTableItemState::LOADED; + table1_.Get(4)->state = ms::TaskTableItemState::EXECUTING; + table1_.Get(5)->state = ms::TaskTableItemState::EXECUTED; + table1_.Get(6)->state = ms::TaskTableItemState::MOVING; + table1_.Get(7)->state = ms::TaskTableItemState::MOVED; } - TaskTable table1_; + ms::TaskTable table1_; }; TEST_F(TaskTableAdvanceTest, LOAD) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -369,8 +371,8 @@ TEST_F(TaskTableAdvanceTest, LOAD) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::START) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADING); + if (before_state[i] == ms::TaskTableItemState::START) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::LOADING); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } @@ -378,7 +380,7 @@ TEST_F(TaskTableAdvanceTest, LOAD) { } TEST_F(TaskTableAdvanceTest, LOADED) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -388,8 +390,8 @@ TEST_F(TaskTableAdvanceTest, LOADED) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::LOADING) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::LOADED); + if (before_state[i] == ms::TaskTableItemState::LOADING) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::LOADED); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } @@ -397,7 +399,7 @@ TEST_F(TaskTableAdvanceTest, LOADED) { } TEST_F(TaskTableAdvanceTest, EXECUTE) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -407,8 +409,8 @@ TEST_F(TaskTableAdvanceTest, EXECUTE) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::LOADED) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTING); + if (before_state[i] == ms::TaskTableItemState::LOADED) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::EXECUTING); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } @@ -416,7 +418,7 @@ TEST_F(TaskTableAdvanceTest, EXECUTE) { } TEST_F(TaskTableAdvanceTest, EXECUTED) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -426,8 +428,8 @@ TEST_F(TaskTableAdvanceTest, EXECUTED) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::EXECUTING) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::EXECUTED); + if (before_state[i] == ms::TaskTableItemState::EXECUTING) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::EXECUTED); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } @@ -435,7 +437,7 @@ TEST_F(TaskTableAdvanceTest, EXECUTED) { } TEST_F(TaskTableAdvanceTest, MOVE) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -445,8 +447,8 @@ TEST_F(TaskTableAdvanceTest, MOVE) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::LOADED) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVING); + if (before_state[i] == ms::TaskTableItemState::LOADED) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::MOVING); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } @@ -454,7 +456,7 @@ TEST_F(TaskTableAdvanceTest, MOVE) { } TEST_F(TaskTableAdvanceTest, MOVED) { - std::vector before_state; + std::vector before_state; for (auto &task : table1_) { before_state.push_back(task->state); } @@ -464,8 +466,8 @@ TEST_F(TaskTableAdvanceTest, MOVED) { } for (size_t i = 0; i < table1_.Size(); ++i) { - if (before_state[i] == TaskTableItemState::MOVING) { - ASSERT_EQ(table1_.Get(i)->state, TaskTableItemState::MOVED); + if (before_state[i] == ms::TaskTableItemState::MOVING) { + ASSERT_EQ(table1_.Get(i)->state, ms::TaskTableItemState::MOVED); } else { ASSERT_EQ(table1_.Get(i)->state, before_state[i]); } diff --git a/cpp/unittest/server/CMakeLists.txt b/cpp/unittest/server/CMakeLists.txt index d5bbd6cc47f5783f89b4ffac7c30b3c5788a36e5..4420e2a1a7777c6e9261799f7a0a7b2abfe1fc79 100644 --- a/cpp/unittest/server/CMakeLists.txt +++ b/cpp/unittest/server/CMakeLists.txt @@ -43,11 +43,11 @@ set(server_test_files ${grpc_server_files} ${grpc_service_files} ${util_files} - ${unittest_files} + ${entry_file} ${test_files} ) -cuda_add_executable(server_test ${server_test_files}) +cuda_add_executable(test_server ${server_test_files}) set(client_grpc_lib grpcpp_channelz @@ -56,7 +56,7 @@ set(client_grpc_lib grpc_protobuf grpc_protoc) -target_link_libraries(server_test +target_link_libraries(test_server knowhere stdc++ snappy @@ -66,7 +66,7 @@ target_link_libraries(server_test ${unittest_libs} ) -install(TARGETS server_test DESTINATION unittest) +install(TARGETS test_server DESTINATION unittest) configure_file(appendix/server_config.yaml "${CMAKE_CURRENT_BINARY_DIR}/milvus/conf/server_config.yaml" diff --git a/cpp/unittest/server/cache_test.cpp b/cpp/unittest/server/test_cache.cpp similarity index 56% rename from cpp/unittest/server/cache_test.cpp rename to cpp/unittest/server/test_cache.cpp index 290caf4a970633e6bfa1dadb0ca384451fefc3bf..754f23f51be8548e178322df3e2ce06f0587e292 100644 --- a/cpp/unittest/server/cache_test.cpp +++ b/cpp/unittest/server/test_cache.cpp @@ -21,73 +21,72 @@ #include "utils/Error.h" #include "wrapper/VecIndex.h" -using namespace zilliz::milvus; - namespace { -class InvalidCacheMgr : public cache::CacheMgr { -public: +namespace ms = milvus; + +class InvalidCacheMgr : public ms::cache::CacheMgr { + public: InvalidCacheMgr() { } }; -class LessItemCacheMgr : public cache::CacheMgr { -public: +class LessItemCacheMgr : public ms::cache::CacheMgr { + public: LessItemCacheMgr() { - cache_ = std::make_shared>(1UL << 12, 10); + cache_ = std::make_shared>(1UL << 12, 10); } }; -class MockVecIndex : public engine::VecIndex { -public: +class MockVecIndex : public ms::engine::VecIndex { + public: MockVecIndex(int64_t dim, int64_t total) : dimension_(dim), - ntotal_(total){ - + ntotal_(total) { } - virtual Status BuildAll(const long &nb, - const float *xb, - const long *ids, - const engine::Config &cfg, - const long &nt = 0, - const float *xt = nullptr) { - return Status(); + virtual ms::Status BuildAll(const int64_t &nb, + const float *xb, + const int64_t *ids, + const ms::engine::Config &cfg, + const int64_t &nt = 0, + const float *xt = nullptr) { + return ms::Status(); } - engine::VecIndexPtr Clone() override { - return zilliz::milvus::engine::VecIndexPtr(); + ms::engine::VecIndexPtr Clone() override { + return milvus::engine::VecIndexPtr(); } int64_t GetDeviceId() override { return 0; } - engine::IndexType GetType() override { - return engine::IndexType::INVALID; + ms::engine::IndexType GetType() override { + return ms::engine::IndexType::INVALID; } - virtual Status Add(const long &nb, - const float *xb, - const long *ids, - const engine::Config &cfg = engine::Config()) { - return Status(); + virtual ms::Status Add(const int64_t &nb, + const float *xb, + const int64_t *ids, + const ms::engine::Config &cfg = ms::engine::Config()) { + return ms::Status(); } - virtual Status Search(const long &nq, - const float *xq, - float *dist, - long *ids, - const engine::Config &cfg = engine::Config()) { - return Status(); + virtual ms::Status Search(const int64_t &nq, + const float *xq, + float *dist, + int64_t *ids, + const ms::engine::Config &cfg = ms::engine::Config()) { + return ms::Status(); } - engine::VecIndexPtr CopyToGpu(const int64_t &device_id, - const engine::Config &cfg) override { + ms::engine::VecIndexPtr CopyToGpu(const int64_t &device_id, + const ms::engine::Config &cfg) override { return nullptr; } - engine::VecIndexPtr CopyToCpu(const engine::Config &cfg) override { + ms::engine::VecIndexPtr CopyToCpu(const ms::engine::Config &cfg) override { return nullptr; } @@ -99,24 +98,24 @@ public: return ntotal_; } - virtual zilliz::knowhere::BinarySet Serialize() { - zilliz::knowhere::BinarySet binset; + virtual knowhere::BinarySet Serialize() { + knowhere::BinarySet binset; return binset; } - virtual Status Load(const zilliz::knowhere::BinarySet &index_binary) { - return Status(); + virtual ms::Status Load(const knowhere::BinarySet &index_binary) { + return ms::Status(); } -public: + public: int64_t dimension_ = 256; int64_t ntotal_ = 0; }; -} +} // namespace TEST(CacheTest, DUMMY_TEST) { - engine::Config cfg; + ms::engine::Config cfg; MockVecIndex mock_index(256, 1000); mock_index.Dimension(); mock_index.Count(); @@ -128,15 +127,15 @@ TEST(CacheTest, DUMMY_TEST) { mock_index.CopyToGpu(1, cfg); mock_index.GetDeviceId(); mock_index.GetType(); - zilliz::knowhere::BinarySet index_binary; + knowhere::BinarySet index_binary; mock_index.Load(index_binary); mock_index.Serialize(); } TEST(CacheTest, CPU_CACHE_TEST) { - auto cpu_mgr = cache::CpuCacheMgr::GetInstance(); + auto cpu_mgr = ms::cache::CpuCacheMgr::GetInstance(); - const int64_t gbyte = 1024*1024*1024; + const int64_t gbyte = 1024 * 1024 * 1024; int64_t g_num = 16; int64_t cap = g_num * gbyte; cpu_mgr->SetCapacity(cap); @@ -145,8 +144,8 @@ TEST(CacheTest, CPU_CACHE_TEST) { uint64_t item_count = 20; for (uint64_t i = 0; i < item_count; i++) { //each vector is 1k byte, total size less than 1G - engine::VecIndexPtr mock_index = std::make_shared(256, 1000000); - cache::DataObjPtr data_obj = std::make_shared(mock_index); + ms::engine::VecIndexPtr mock_index = std::make_shared(256, 1000000); + ms::cache::DataObjPtr data_obj = std::make_shared(mock_index); cpu_mgr->InsertItem("index_" + std::to_string(i), data_obj); } ASSERT_LT(cpu_mgr->ItemCount(), g_num); @@ -169,8 +168,8 @@ TEST(CacheTest, CPU_CACHE_TEST) { cpu_mgr->SetCapacity(g_num * gbyte); //each vector is 1k byte, total size less than 6G - engine::VecIndexPtr mock_index = std::make_shared(256, 6000000); - cache::DataObjPtr data_obj = std::make_shared(mock_index); + ms::engine::VecIndexPtr mock_index = std::make_shared(256, 6000000); + ms::cache::DataObjPtr data_obj = std::make_shared(mock_index); cpu_mgr->InsertItem("index_6g", data_obj); ASSERT_TRUE(cpu_mgr->ItemExists("index_6g")); } @@ -179,12 +178,12 @@ TEST(CacheTest, CPU_CACHE_TEST) { } TEST(CacheTest, GPU_CACHE_TEST) { - auto gpu_mgr = cache::GpuCacheMgr::GetInstance(0); + auto gpu_mgr = ms::cache::GpuCacheMgr::GetInstance(0); - for(int i = 0; i < 20; i++) { + for (int i = 0; i < 20; i++) { //each vector is 1k byte - engine::VecIndexPtr mock_index = std::make_shared(256, 1000); - cache::DataObjPtr data_obj = std::make_shared(mock_index); + ms::engine::VecIndexPtr mock_index = std::make_shared(256, 1000); + ms::cache::DataObjPtr data_obj = std::make_shared(mock_index); gpu_mgr->InsertItem("index_" + std::to_string(i), data_obj); } @@ -194,17 +193,16 @@ TEST(CacheTest, GPU_CACHE_TEST) { ASSERT_EQ(gpu_mgr->ItemCount(), 0); for (auto i = 0; i < 3; i++) { - // TODO: use gpu index to mock + // TODO(myh): use gpu index to mock //each vector is 1k byte, total size less than 2G - engine::VecIndexPtr mock_index = std::make_shared(256, 2000000); - cache::DataObjPtr data_obj = std::make_shared(mock_index); - std::cout << data_obj->size() <(256, 2000000); + ms::cache::DataObjPtr data_obj = std::make_shared(mock_index); + std::cout << data_obj->size() << std::endl; gpu_mgr->InsertItem("index_" + std::to_string(i), data_obj); } gpu_mgr->ClearCache(); ASSERT_EQ(gpu_mgr->ItemCount(), 0); - } TEST(CacheTest, INVALID_TEST) { @@ -214,7 +212,7 @@ TEST(CacheTest, INVALID_TEST) { ASSERT_FALSE(mgr.ItemExists("test")); ASSERT_EQ(mgr.GetItem("test"), nullptr); - mgr.InsertItem("test", cache::DataObjPtr()); + mgr.InsertItem("test", ms::cache::DataObjPtr()); mgr.InsertItem("test", nullptr); mgr.EraseItem("test"); mgr.PrintInfo(); @@ -226,12 +224,12 @@ TEST(CacheTest, INVALID_TEST) { { LessItemCacheMgr mgr; - for(int i = 0; i < 20; i++) { + for (int i = 0; i < 20; i++) { //each vector is 1k byte - engine::VecIndexPtr mock_index = std::make_shared(256, 2); - cache::DataObjPtr data_obj = std::make_shared(mock_index); + ms::engine::VecIndexPtr mock_index = std::make_shared(256, 2); + ms::cache::DataObjPtr data_obj = std::make_shared(mock_index); mgr.InsertItem("index_" + std::to_string(i), data_obj); } ASSERT_EQ(mgr.GetItem("index_0"), nullptr); } -} \ No newline at end of file +} diff --git a/cpp/unittest/server/config_test.cpp b/cpp/unittest/server/test_config.cpp similarity index 68% rename from cpp/unittest/server/config_test.cpp rename to cpp/unittest/server/test_config.cpp index b27d1e60c818967d8e768866e4f73c0428881dbe..ef708a8382e3949875d455b8f803eadcfd4c90d6 100644 --- a/cpp/unittest/server/config_test.cpp +++ b/cpp/unittest/server/test_config.cpp @@ -23,39 +23,39 @@ #include "utils/ValidationUtil.h" #include "server/Config.h" -using namespace zilliz::milvus; - namespace { -static const char* CONFIG_FILE_PATH = "./milvus/conf/server_config.yaml"; -static const char* LOG_FILE_PATH = "./milvus/conf/log_config.conf"; +namespace ms = milvus; + +static const char *CONFIG_FILE_PATH = "./milvus/conf/server_config.yaml"; +static const char *LOG_FILE_PATH = "./milvus/conf/log_config.conf"; static constexpr uint64_t KB = 1024; -static constexpr uint64_t MB = KB*1024; -static constexpr uint64_t GB = MB*1024; +static constexpr uint64_t MB = KB * 1024; +static constexpr uint64_t GB = MB * 1024; -} +} // namespace TEST(ConfigTest, CONFIG_TEST) { - server::ConfigMgr* config_mgr = server::ConfigMgr::GetInstance(); + ms::server::ConfigMgr *config_mgr = ms::server::ConfigMgr::GetInstance(); - ErrorCode err = config_mgr->LoadConfigFile(""); - ASSERT_EQ(err, SERVER_UNEXPECTED_ERROR); + ms::ErrorCode err = config_mgr->LoadConfigFile(""); + ASSERT_EQ(err, ms::SERVER_UNEXPECTED_ERROR); err = config_mgr->LoadConfigFile(LOG_FILE_PATH); - ASSERT_EQ(err, SERVER_UNEXPECTED_ERROR); + ASSERT_EQ(err, ms::SERVER_UNEXPECTED_ERROR); err = config_mgr->LoadConfigFile(CONFIG_FILE_PATH); - ASSERT_EQ(err, SERVER_SUCCESS); + ASSERT_EQ(err, ms::SERVER_SUCCESS); config_mgr->Print(); - server::ConfigNode& root_config = config_mgr->GetRootNode(); - server::ConfigNode& server_config = root_config.GetChild("server_config"); - server::ConfigNode& db_config = root_config.GetChild("db_config"); - server::ConfigNode& metric_config = root_config.GetChild("metric_config"); - server::ConfigNode& cache_config = root_config.GetChild("cache_config"); - server::ConfigNode invalid_config = root_config.GetChild("invalid_config"); + ms::server::ConfigNode &root_config = config_mgr->GetRootNode(); + ms::server::ConfigNode &server_config = root_config.GetChild("server_config"); + ms::server::ConfigNode &db_config = root_config.GetChild("db_config"); + ms::server::ConfigNode &metric_config = root_config.GetChild("metric_config"); + ms::server::ConfigNode &cache_config = root_config.GetChild("cache_config"); + ms::server::ConfigNode invalid_config = root_config.GetChild("invalid_config"); auto valus = invalid_config.GetSequence("not_exist"); float ff = invalid_config.GetFloatValue("not_exist", 3.0); ASSERT_EQ(ff, 3.0); @@ -63,16 +63,16 @@ TEST(ConfigTest, CONFIG_TEST) { std::string address = server_config.GetValue("address"); ASSERT_TRUE(!address.empty()); int64_t port = server_config.GetInt64Value("port"); - ASSERT_TRUE(port != 0); + ASSERT_NE(port, 0); server_config.SetValue("test", "2.5"); double test = server_config.GetDoubleValue("test"); ASSERT_EQ(test, 2.5); - server::ConfigNode fake; + ms::server::ConfigNode fake; server_config.AddChild("fake", fake); fake = server_config.GetChild("fake"); - server::ConfigNodeArr arr; + ms::server::ConfigNodeArr arr; server_config.GetChildren(arr); ASSERT_EQ(arr.size(), 1UL); @@ -89,7 +89,7 @@ TEST(ConfigTest, CONFIG_TEST) { auto seq = server_config.GetSequence("seq"); ASSERT_EQ(seq.size(), 2UL); - server::ConfigNode combine; + ms::server::ConfigNode combine; combine.Combine(server_config); combine.PrintAll(); @@ -102,8 +102,8 @@ TEST(ConfigTest, CONFIG_TEST) { } TEST(ConfigTest, SERVER_CONFIG_TEST) { - server::Config& config = server::Config::GetInstance(); - Status s = config.LoadConfigFile(CONFIG_FILE_PATH); + ms::server::Config &config = ms::server::Config::GetInstance(); + ms::Status s = config.LoadConfigFile(CONFIG_FILE_PATH); ASSERT_TRUE(s.ok()); s = config.ValidateConfig(); @@ -113,4 +113,4 @@ TEST(ConfigTest, SERVER_CONFIG_TEST) { s = config.ResetDefaultConfig(); ASSERT_TRUE(s.ok()); -} \ No newline at end of file +} diff --git a/cpp/unittest/server/rpc_test.cpp b/cpp/unittest/server/test_rpc.cpp similarity index 82% rename from cpp/unittest/server/rpc_test.cpp rename to cpp/unittest/server/test_rpc.cpp index 0eaadf61465d7b726db0a40814524fbf23a6163e..b847ec31167c1b5ec039072c3066cd374bfebe14 100644 --- a/cpp/unittest/server/rpc_test.cpp +++ b/cpp/unittest/server/test_rpc.cpp @@ -23,7 +23,7 @@ #include "server/grpc_impl/GrpcRequestHandler.h" #include "server/grpc_impl/GrpcRequestScheduler.h" #include "server/grpc_impl/GrpcRequestTask.h" -#include "version.h" +#include "../version.h" #include "grpc/gen-milvus/milvus.grpc.pb.h" #include "grpc/gen-status/status.pb.h" @@ -34,10 +34,10 @@ #include "scheduler/ResourceFactory.h" #include "utils/CommonUtil.h" -using namespace zilliz::milvus; - namespace { +namespace ms = milvus; + static const char *TABLE_NAME = "test_grpc"; static constexpr int64_t TABLE_DIM = 256; static constexpr int64_t INDEX_FILE_SIZE = 1024; @@ -49,30 +49,29 @@ class RpcHandlerTest : public testing::Test { protected: void SetUp() override { - - auto res_mgr = scheduler::ResMgrInst::GetInstance(); + auto res_mgr = ms::scheduler::ResMgrInst::GetInstance(); res_mgr->Clear(); - res_mgr->Add(scheduler::ResourceFactory::Create("disk", "DISK", 0, true, false)); - res_mgr->Add(scheduler::ResourceFactory::Create("cpu", "CPU", 0, true, true)); - res_mgr->Add(scheduler::ResourceFactory::Create("gtx1660", "GPU", 0, true, true)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("disk", "DISK", 0, true, false)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("cpu", "CPU", 0, true, true)); + res_mgr->Add(ms::scheduler::ResourceFactory::Create("gtx1660", "GPU", 0, true, true)); - auto default_conn = scheduler::Connection("IO", 500.0); - auto PCIE = scheduler::Connection("IO", 11000.0); + auto default_conn = ms::scheduler::Connection("IO", 500.0); + auto PCIE = ms::scheduler::Connection("IO", 11000.0); res_mgr->Connect("disk", "cpu", default_conn); res_mgr->Connect("cpu", "gtx1660", PCIE); res_mgr->Start(); - scheduler::SchedInst::GetInstance()->Start(); - scheduler::JobMgrInst::GetInstance()->Start(); + ms::scheduler::SchedInst::GetInstance()->Start(); + ms::scheduler::JobMgrInst::GetInstance()->Start(); - engine::DBOptions opt; + ms::engine::DBOptions opt; - server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/"); - server::Config::GetInstance().SetDBConfigPrimaryPath("/tmp/milvus_test"); - server::Config::GetInstance().SetDBConfigSecondaryPath(""); - server::Config::GetInstance().SetDBConfigArchiveDiskThreshold(""); - server::Config::GetInstance().SetDBConfigArchiveDaysThreshold(""); - server::Config::GetInstance().SetCacheConfigCacheInsertData(""); - server::Config::GetInstance().SetEngineConfigOmpThreadNum(""); + ms::server::Config::GetInstance().SetDBConfigBackendUrl("sqlite://:@:/"); + ms::server::Config::GetInstance().SetDBConfigPrimaryPath("/tmp/milvus_test"); + ms::server::Config::GetInstance().SetDBConfigSecondaryPath(""); + ms::server::Config::GetInstance().SetDBConfigArchiveDiskThreshold(""); + ms::server::Config::GetInstance().SetDBConfigArchiveDaysThreshold(""); + ms::server::Config::GetInstance().SetCacheConfigCacheInsertData(""); + ms::server::Config::GetInstance().SetEngineConfigOmpThreadNum(""); // serverConfig.SetValue(server::CONFIG_CLUSTER_MODE, "cluster"); // DBWrapper::GetInstance().GetInstance().StartService(); @@ -82,11 +81,11 @@ class RpcHandlerTest : public testing::Test { // DBWrapper::GetInstance().GetInstance().StartService(); // DBWrapper::GetInstance().GetInstance().StopService(); - server::Config::GetInstance().SetResourceConfigMode("single"); - server::DBWrapper::GetInstance().StartService(); + ms::server::Config::GetInstance().SetResourceConfigMode("single"); + ms::server::DBWrapper::GetInstance().StartService(); //initialize handler, create table - handler = std::make_shared(); + handler = std::make_shared(); ::grpc::ServerContext context; ::milvus::grpc::TableSchema request; ::milvus::grpc::Status status; @@ -99,18 +98,20 @@ class RpcHandlerTest : public testing::Test { void TearDown() override { - server::DBWrapper::GetInstance().StopService(); - scheduler::JobMgrInst::GetInstance()->Stop(); - scheduler::ResMgrInst::GetInstance()->Stop(); - scheduler::SchedInst::GetInstance()->Stop(); + ms::server::DBWrapper::GetInstance().StopService(); + ms::scheduler::JobMgrInst::GetInstance()->Stop(); + ms::scheduler::ResMgrInst::GetInstance()->Stop(); + ms::scheduler::SchedInst::GetInstance()->Stop(); boost::filesystem::remove_all("/tmp/milvus_test"); } + protected: - std::shared_ptr handler; + std::shared_ptr handler; }; -void BuildVectors(int64_t from, int64_t to, - std::vector> &vector_record_array) { +void +BuildVectors(int64_t from, int64_t to, + std::vector> &vector_record_array) { if (to <= from) { return; } @@ -127,20 +128,22 @@ void BuildVectors(int64_t from, int64_t to, } } -std::string CurrentTmDate(int64_t offset_day = 0) { +std::string +CurrentTmDate(int64_t offset_day = 0) { time_t tt; time(&tt); tt = tt + 8 * SECONDS_EACH_HOUR; tt = tt + 24 * SECONDS_EACH_HOUR * offset_day; - tm *t = gmtime(&tt); + tm t; + gmtime_r(&tt, &t); - std::string str = std::to_string(t->tm_year + 1900) + "-" + std::to_string(t->tm_mon + 1) - + "-" + std::to_string(t->tm_mday); + std::string str = std::to_string(t.tm_year + 1900) + "-" + std::to_string(t.tm_mon + 1) + + "-" + std::to_string(t.tm_mday); return str; } -} +} // namespace TEST_F(RpcHandlerTest, HAS_TABLE_TEST) { ::grpc::ServerContext context; @@ -311,7 +314,6 @@ TEST_F(RpcHandlerTest, TABLES_TEST) { ::grpc::Status status = handler->DescribeTable(&context, &table_name, &table_schema); ASSERT_EQ(status.error_code(), ::grpc::Status::OK.error_code()); - ::milvus::grpc::InsertParam request; std::vector> record_array; BuildVectors(0, VECTOR_COUNT, record_array); @@ -349,11 +351,11 @@ TEST_F(RpcHandlerTest, TABLES_TEST) { handler->Insert(&context, &request, &vector_ids); -//Show table -// ::milvus::grpc::Command cmd; -// ::grpc::ServerWriter<::milvus::grpc::TableName> *writer; -// status = handler->ShowTables(&context, &cmd, writer); -// ASSERT_EQ(status.error_code(), ::grpc::Status::OK.error_code()); + //show tables + ::milvus::grpc::Command cmd; + ::milvus::grpc::TableNameList table_name_list; + status = handler->ShowTables(&context, &cmd, &table_name_list); + ASSERT_EQ(status.error_code(), ::grpc::Status::OK.error_code()); //Count Table ::milvus::grpc::TableRowCount count; @@ -419,31 +421,29 @@ TEST_F(RpcHandlerTest, DELETE_BY_RANGE_TEST) { grpc_status = handler->DeleteByRange(&context, &request, &status); request.mutable_range()->set_end_value(CurrentTmDate(-2)); grpc_status = handler->DeleteByRange(&context, &request, &status); - } ////////////////////////////////////////////////////////////////////// namespace { -class DummyTask : public server::grpc::GrpcBaseTask { -public: - Status +class DummyTask : public ms::server::grpc::GrpcBaseTask { + public: + ms::Status OnExecute() override { - return Status::OK(); + return ms::Status::OK(); } - static server::grpc::BaseTaskPtr + static ms::server::grpc::BaseTaskPtr Create(std::string &dummy) { - return std::shared_ptr(new DummyTask(dummy)); + return std::shared_ptr(new DummyTask(dummy)); } -public: + public: explicit DummyTask(std::string &dummy) : GrpcBaseTask(dummy) { - } }; class RpcSchedulerTest : public testing::Test { -protected: + protected: void SetUp() override { std::string dummy = "dql"; @@ -453,22 +453,22 @@ protected: std::shared_ptr task_ptr; }; -} +} // namespace -TEST_F(RpcSchedulerTest, BASE_TASK_TEST){ +TEST_F(RpcSchedulerTest, BASE_TASK_TEST) { auto status = task_ptr->Execute(); ASSERT_TRUE(status.ok()); - server::grpc::GrpcRequestScheduler::GetInstance().Start(); + ms::server::grpc::GrpcRequestScheduler::GetInstance().Start(); ::milvus::grpc::Status grpc_status; std::string dummy = "dql"; - server::grpc::BaseTaskPtr base_task_ptr = DummyTask::Create(dummy); - server::grpc::GrpcRequestScheduler::GetInstance().ExecTask(base_task_ptr, &grpc_status); + ms::server::grpc::BaseTaskPtr base_task_ptr = DummyTask::Create(dummy); + ms::server::grpc::GrpcRequestScheduler::GetInstance().ExecTask(base_task_ptr, &grpc_status); - server::grpc::GrpcRequestScheduler::GetInstance().ExecuteTask(task_ptr); + ms::server::grpc::GrpcRequestScheduler::GetInstance().ExecuteTask(task_ptr); task_ptr = nullptr; - server::grpc::GrpcRequestScheduler::GetInstance().ExecuteTask(task_ptr); + ms::server::grpc::GrpcRequestScheduler::GetInstance().ExecuteTask(task_ptr); - server::grpc::GrpcRequestScheduler::GetInstance().Stop(); + ms::server::grpc::GrpcRequestScheduler::GetInstance().Stop(); } diff --git a/cpp/unittest/server/util_test.cpp b/cpp/unittest/server/util_test.cpp index 1757a7ddfea67b35b42ea98763d396f248dfe689..0ee214a7e7c06787ce822d4a58cb47543c0f5289 100644 --- a/cpp/unittest/server/util_test.cpp +++ b/cpp/unittest/server/util_test.cpp @@ -31,75 +31,76 @@ #include #include -using namespace zilliz::milvus; - namespace { -static const char* LOG_FILE_PATH = "./milvus/conf/log_config.conf"; +namespace ms = milvus; + +static const char *LOG_FILE_PATH = "./milvus/conf/log_config.conf"; -void CopyStatus(Status& st1, Status& st2) { +void +CopyStatus(ms::Status &st1, ms::Status &st2) { st1 = st2; } -} +} // namespace TEST(UtilTest, EXCEPTION_TEST) { std::string err_msg = "failed"; - server::ServerException ex(SERVER_UNEXPECTED_ERROR, err_msg); - ASSERT_EQ(ex.error_code(), SERVER_UNEXPECTED_ERROR); + ms::server::ServerException ex(ms::SERVER_UNEXPECTED_ERROR, err_msg); + ASSERT_EQ(ex.error_code(), ms::SERVER_UNEXPECTED_ERROR); std::string msg = ex.what(); ASSERT_EQ(msg, err_msg); } TEST(UtilTest, SIGNAL_TEST) { - server::SignalUtil::PrintStacktrace(); + ms::server::SignalUtil::PrintStacktrace(); } TEST(UtilTest, COMMON_TEST) { - unsigned long total_mem = 0, free_mem = 0; - server::CommonUtil::GetSystemMemInfo(total_mem, free_mem); + uint64_t total_mem = 0, free_mem = 0; + ms::server::CommonUtil::GetSystemMemInfo(total_mem, free_mem); ASSERT_GT(total_mem, 0); ASSERT_GT(free_mem, 0); - unsigned int thread_cnt = 0; - server::CommonUtil::GetSystemAvailableThreads(thread_cnt); + uint32_t thread_cnt = 0; + ms::server::CommonUtil::GetSystemAvailableThreads(thread_cnt); ASSERT_GT(thread_cnt, 0); std::string path1 = "/tmp/milvus_test/"; std::string path2 = path1 + "common_test_12345/"; std::string path3 = path2 + "abcdef"; - Status status = server::CommonUtil::CreateDirectory(path3); + ms::Status status = ms::server::CommonUtil::CreateDirectory(path3); ASSERT_TRUE(status.ok()); //test again - status = server::CommonUtil::CreateDirectory(path3); + status = ms::server::CommonUtil::CreateDirectory(path3); ASSERT_TRUE(status.ok()); - ASSERT_TRUE(server::CommonUtil::IsDirectoryExist(path3)); + ASSERT_TRUE(ms::server::CommonUtil::IsDirectoryExist(path3)); - status = server::CommonUtil::DeleteDirectory(path1); + status = ms::server::CommonUtil::DeleteDirectory(path1); ASSERT_TRUE(status.ok()); //test again - status = server::CommonUtil::DeleteDirectory(path1); + status = ms::server::CommonUtil::DeleteDirectory(path1); ASSERT_TRUE(status.ok()); - ASSERT_FALSE(server::CommonUtil::IsDirectoryExist(path1)); - ASSERT_FALSE(server::CommonUtil::IsFileExist(path1)); + ASSERT_FALSE(ms::server::CommonUtil::IsDirectoryExist(path1)); + ASSERT_FALSE(ms::server::CommonUtil::IsFileExist(path1)); - std::string exe_path = server::CommonUtil::GetExePath(); + std::string exe_path = ms::server::CommonUtil::GetExePath(); ASSERT_FALSE(exe_path.empty()); time_t tt; - time( &tt ); + time(&tt); tm time_struct; memset(&time_struct, 0, sizeof(tm)); - server::CommonUtil::ConvertTime(tt, time_struct); + ms::server::CommonUtil::ConvertTime(tt, time_struct); ASSERT_GT(time_struct.tm_year, 0); ASSERT_GT(time_struct.tm_mon, 0); ASSERT_GT(time_struct.tm_mday, 0); - server::CommonUtil::ConvertTime(time_struct, tt); + ms::server::CommonUtil::ConvertTime(time_struct, tt); ASSERT_GT(tt, 0); - bool res = server::CommonUtil::TimeStrToTime("2019-03-23", tt, time_struct); + bool res = ms::server::CommonUtil::TimeStrToTime("2019-03-23", tt, time_struct); ASSERT_EQ(time_struct.tm_year, 119); ASSERT_EQ(time_struct.tm_mon, 2); ASSERT_EQ(time_struct.tm_mday, 23); @@ -109,45 +110,43 @@ TEST(UtilTest, COMMON_TEST) { TEST(UtilTest, STRINGFUNCTIONS_TEST) { std::string str = " test zilliz"; - server::StringHelpFunctions::TrimStringBlank(str); + ms::server::StringHelpFunctions::TrimStringBlank(str); ASSERT_EQ(str, "test zilliz"); str = "\"test zilliz\""; - server::StringHelpFunctions::TrimStringQuote(str, "\""); + ms::server::StringHelpFunctions::TrimStringQuote(str, "\""); ASSERT_EQ(str, "test zilliz"); str = "a,b,c"; std::vector result; - auto status = server::StringHelpFunctions::SplitStringByDelimeter(str , ",", result); + auto status = ms::server::StringHelpFunctions::SplitStringByDelimeter(str, ",", result); ASSERT_TRUE(status.ok()); ASSERT_EQ(result.size(), 3UL); result.clear(); - status = server::StringHelpFunctions::SplitStringByQuote(str , ",", "\"", result); + status = ms::server::StringHelpFunctions::SplitStringByQuote(str, ",", "\"", result); ASSERT_TRUE(status.ok()); ASSERT_EQ(result.size(), 3UL); result.clear(); - status = server::StringHelpFunctions::SplitStringByQuote(str , ",", "", result); + status = ms::server::StringHelpFunctions::SplitStringByQuote(str, ",", "", result); ASSERT_TRUE(status.ok()); ASSERT_EQ(result.size(), 3UL); str = "55,\"aa,gg,yy\",b"; result.clear(); - status = server::StringHelpFunctions::SplitStringByQuote(str , ",", "\"", result); + status = ms::server::StringHelpFunctions::SplitStringByQuote(str, ",", "\"", result); ASSERT_TRUE(status.ok()); ASSERT_EQ(result.size(), 3UL); - - } TEST(UtilTest, BLOCKINGQUEUE_TEST) { - server::BlockingQueue bq; + ms::server::BlockingQueue bq; static const size_t count = 10; bq.SetCapacity(count); - for(size_t i = 1; i <= count; i++) { + for (size_t i = 1; i <= count; i++) { std::string id = "No." + std::to_string(i); bq.Put(id); } @@ -161,7 +160,7 @@ TEST(UtilTest, BLOCKINGQUEUE_TEST) { str = bq.Back(); ASSERT_EQ(str, "No." + std::to_string(count)); - for(size_t i = 1; i <= count; i++) { + for (size_t i = 1; i <= count; i++) { std::string id = "No." + std::to_string(i); str = bq.Take(); ASSERT_EQ(id, str); @@ -171,54 +170,54 @@ TEST(UtilTest, BLOCKINGQUEUE_TEST) { } TEST(UtilTest, LOG_TEST) { - auto status = server::InitLog(LOG_FILE_PATH); + auto status = ms::server::InitLog(LOG_FILE_PATH); ASSERT_TRUE(status.ok()); EXPECT_FALSE(el::Loggers::hasFlag(el::LoggingFlag::NewLineForContainer)); EXPECT_FALSE(el::Loggers::hasFlag(el::LoggingFlag::LogDetailedCrashReason)); - std::string fname = server::CommonUtil::GetFileName(LOG_FILE_PATH); + std::string fname = ms::server::CommonUtil::GetFileName(LOG_FILE_PATH); ASSERT_EQ(fname, "log_config.conf"); } TEST(UtilTest, TIMERECORDER_TEST) { - for(int64_t log_level = 0; log_level <= 6; log_level++) { - if(log_level == 5) { + for (int64_t log_level = 0; log_level <= 6; log_level++) { + if (log_level == 5) { continue; //skip fatal } - TimeRecorder rc("time", log_level); + ms::TimeRecorder rc("time", log_level); rc.RecordSection("end"); } } TEST(UtilTest, STATUS_TEST) { - auto status = Status::OK(); + auto status = ms::Status::OK(); std::string str = status.ToString(); ASSERT_FALSE(str.empty()); - status = Status(DB_ERROR, "mistake"); - ASSERT_EQ(status.code(), DB_ERROR); + status = ms::Status(ms::DB_ERROR, "mistake"); + ASSERT_EQ(status.code(), ms::DB_ERROR); str = status.ToString(); ASSERT_FALSE(str.empty()); - status = Status(DB_NOT_FOUND, "mistake"); - ASSERT_EQ(status.code(), DB_NOT_FOUND); + status = ms::Status(ms::DB_NOT_FOUND, "mistake"); + ASSERT_EQ(status.code(), ms::DB_NOT_FOUND); str = status.ToString(); ASSERT_FALSE(str.empty()); - status = Status(DB_ALREADY_EXIST, "mistake"); - ASSERT_EQ(status.code(), DB_ALREADY_EXIST); + status = ms::Status(ms::DB_ALREADY_EXIST, "mistake"); + ASSERT_EQ(status.code(), ms::DB_ALREADY_EXIST); str = status.ToString(); ASSERT_FALSE(str.empty()); - status = Status(DB_META_TRANSACTION_FAILED, "mistake"); - ASSERT_EQ(status.code(), DB_META_TRANSACTION_FAILED); + status = ms::Status(ms::DB_META_TRANSACTION_FAILED, "mistake"); + ASSERT_EQ(status.code(), ms::DB_META_TRANSACTION_FAILED); str = status.ToString(); ASSERT_FALSE(str.empty()); - auto status_copy = Status::OK(); + auto status_copy = ms::Status::OK(); CopyStatus(status_copy, status); - ASSERT_EQ(status.code(), DB_META_TRANSACTION_FAILED); + ASSERT_EQ(status.code(), ms::DB_META_TRANSACTION_FAILED); auto status_ref(status); ASSERT_EQ(status_ref.code(), status.code()); @@ -231,123 +230,140 @@ TEST(UtilTest, STATUS_TEST) { TEST(ValidationUtilTest, VALIDATE_TABLENAME_TEST) { std::string table_name = "Normal123_"; - auto status = server::ValidationUtil::ValidateTableName(table_name); + auto status = ms::server::ValidationUtil::ValidateTableName(table_name); ASSERT_TRUE(status.ok()); table_name = "12sds"; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); table_name = ""; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); table_name = "_asdasd"; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_SUCCESS); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_SUCCESS); table_name = "!@#!@"; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); table_name = "_!@#!@"; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); table_name = "中文"; - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); table_name = std::string(10000, 'a'); - status = server::ValidationUtil::ValidateTableName(table_name); - ASSERT_EQ(status.code(), SERVER_INVALID_TABLE_NAME); + status = ms::server::ValidationUtil::ValidateTableName(table_name); + ASSERT_EQ(status.code(), ms::SERVER_INVALID_TABLE_NAME); } TEST(ValidationUtilTest, VALIDATE_DIMENSION_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateTableDimension(-1).code(), SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(server::ValidationUtil::ValidateTableDimension(0).code(), SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(server::ValidationUtil::ValidateTableDimension(16385).code(), SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(server::ValidationUtil::ValidateTableDimension(16384).code(), SERVER_SUCCESS); - ASSERT_EQ(server::ValidationUtil::ValidateTableDimension(1).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableDimension(-1).code(), ms::SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableDimension(0).code(), ms::SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableDimension(16385).code(), ms::SERVER_INVALID_VECTOR_DIMENSION); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableDimension(16384).code(), ms::SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableDimension(1).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_INDEX_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexType((int)engine::EngineType::INVALID).code(), SERVER_INVALID_INDEX_TYPE); - for(int i = 1; i <= (int)engine::EngineType::MAX_VALUE; i++) { - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexType(i).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexType((int) ms::engine::EngineType::INVALID).code(), + ms::SERVER_INVALID_INDEX_TYPE); + for (int i = 1; i <= (int) ms::engine::EngineType::MAX_VALUE; i++) { + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexType(i).code(), ms::SERVER_SUCCESS); } - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexType((int)engine::EngineType::MAX_VALUE + 1).code(), SERVER_INVALID_INDEX_TYPE); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexType((int) ms::engine::EngineType::MAX_VALUE + 1).code(), + ms::SERVER_INVALID_INDEX_TYPE); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexNlist(0).code(), SERVER_INVALID_INDEX_NLIST); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexNlist(100).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexNlist(0).code(), ms::SERVER_INVALID_INDEX_NLIST); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexNlist(100).code(), ms::SERVER_SUCCESS); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexFileSize(0).code(), SERVER_INVALID_INDEX_FILE_SIZE); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexFileSize(100).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexFileSize(0).code(), ms::SERVER_INVALID_INDEX_FILE_SIZE); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexFileSize(100).code(), ms::SERVER_SUCCESS); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexMetricType(0).code(), SERVER_INVALID_INDEX_METRIC_TYPE); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexMetricType(1).code(), SERVER_SUCCESS); - ASSERT_EQ(server::ValidationUtil::ValidateTableIndexMetricType(2).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexMetricType(0).code(), ms::SERVER_INVALID_INDEX_METRIC_TYPE); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexMetricType(1).code(), ms::SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateTableIndexMetricType(2).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_TOPK_TEST) { - engine::meta::TableSchema schema; - ASSERT_EQ(server::ValidationUtil::ValidateSearchTopk(10, schema).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateSearchTopk(65536, schema).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateSearchTopk(0, schema).code(), SERVER_SUCCESS); + ms::engine::meta::TableSchema schema; + ASSERT_EQ(ms::server::ValidationUtil::ValidateSearchTopk(10, schema).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateSearchTopk(65536, schema).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateSearchTopk(0, schema).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_NPROBE_TEST) { - engine::meta::TableSchema schema; + ms::engine::meta::TableSchema schema; schema.nlist_ = 100; - ASSERT_EQ(server::ValidationUtil::ValidateSearchNprobe(10, schema).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateSearchNprobe(0, schema).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateSearchNprobe(101, schema).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateSearchNprobe(10, schema).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateSearchNprobe(0, schema).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateSearchNprobe(101, schema).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_GPU_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateGpuIndex(0).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateGpuIndex(100).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateGpuIndex(0).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateGpuIndex(100).code(), ms::SERVER_SUCCESS); size_t memory = 0; - ASSERT_EQ(server::ValidationUtil::GetGpuMemory(0, memory).code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::GetGpuMemory(100, memory).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::GetGpuMemory(0, memory).code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::GetGpuMemory(100, memory).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_IPADDRESS_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateIpAddress("127.0.0.1").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateIpAddress("not ip").code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateIpAddress("127.0.0.1").code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateIpAddress("not ip").code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_NUMBER_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateStringIsNumber("1234").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateStringIsNumber("not number").code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateStringIsNumber("1234").code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateStringIsNumber("not number").code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_BOOL_TEST) { std::string str = "true"; - ASSERT_EQ(server::ValidationUtil::ValidateStringIsBool(str).code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateStringIsBool(str).code(), ms::SERVER_SUCCESS); str = "not bool"; - ASSERT_NE(server::ValidationUtil::ValidateStringIsBool(str).code(), SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateStringIsBool(str).code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_DOUBLE_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateStringIsFloat("2.5").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateStringIsFloat("not double").code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateStringIsFloat("2.5").code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateStringIsFloat("not double").code(), ms::SERVER_SUCCESS); } TEST(ValidationUtilTest, VALIDATE_DBURI_TEST) { - ASSERT_EQ(server::ValidationUtil::ValidateDbURI("sqlite://:@:/").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateDbURI("xxx://:@:/").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateDbURI("not uri").code(), SERVER_SUCCESS); - ASSERT_EQ(server::ValidationUtil::ValidateDbURI("mysql://root:123456@127.0.0.1:3303/milvus").code(), SERVER_SUCCESS); - ASSERT_NE(server::ValidationUtil::ValidateDbURI("mysql://root:123456@127.0.0.1:port/milvus").code(), SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateDbURI("sqlite://:@:/").code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateDbURI("xxx://:@:/").code(), ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateDbURI("not uri").code(), ms::SERVER_SUCCESS); + ASSERT_EQ(ms::server::ValidationUtil::ValidateDbURI("mysql://root:123456@127.0.0.1:3303/milvus").code(), + ms::SERVER_SUCCESS); + ASSERT_NE(ms::server::ValidationUtil::ValidateDbURI("mysql://root:123456@127.0.0.1:port/milvus").code(), + ms::SERVER_SUCCESS); } -TEST(UtilTest, ROLLOUTHANDLER_TEST){ +TEST(UtilTest, ROLLOUTHANDLER_TEST) { std::string dir1 = "/tmp/milvus_test"; std::string dir2 = "/tmp/milvus_test/log_test"; - std::string filename[6] = {"log_global.log", "log_debug.log", "log_warning.log", "log_trace.log", "log_error.log", "log_fatal.log"}; - el::Level list[6] = {el::Level::Global, el::Level::Debug, el::Level::Warning, el::Level::Trace, el::Level::Error, el::Level::Fatal}; + std::string filename[6] = { + "log_global.log", + "log_debug.log", + "log_warning.log", + "log_trace.log", + "log_error.log", + "log_fatal.log"}; + + el::Level list[6] = { + el::Level::Global, + el::Level::Debug, + el::Level::Warning, + el::Level::Trace, + el::Level::Error, + el::Level::Fatal}; mkdir(dir1.c_str(), S_IRWXU); mkdir(dir2.c_str(), S_IRWXU); @@ -358,7 +374,7 @@ TEST(UtilTest, ROLLOUTHANDLER_TEST){ file.open(tmp.c_str()); file << "zilliz" << std::endl; - server::RolloutHandler(tmp.c_str(), 0, list[i]); + ms::server::RolloutHandler(tmp.c_str(), 0, list[i]); tmp.append(".1"); std::ifstream file2; @@ -369,4 +385,4 @@ TEST(UtilTest, ROLLOUTHANDLER_TEST){ ASSERT_EQ(tmp2, "zilliz"); } boost::filesystem::remove_all(dir2); -} \ No newline at end of file +} diff --git a/cpp/unittest/wrapper/CMakeLists.txt b/cpp/unittest/wrapper/CMakeLists.txt index 633cbaf0534a9eba5ad3b66c5061236a8ab2634d..8eae47b3d4c710249b311b727939ab3765a41eab 100644 --- a/cpp/unittest/wrapper/CMakeLists.txt +++ b/cpp/unittest/wrapper/CMakeLists.txt @@ -17,8 +17,7 @@ # under the License. #------------------------------------------------------------------------------- -include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include") -link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} test_files) set(wrapper_files ${MILVUS_ENGINE_SRC}/wrapper/DataTransfer.cpp @@ -31,13 +30,13 @@ set(util_files ${MILVUS_ENGINE_SRC}/utils/Status.cpp ) -set(knowhere_libs - knowhere - cudart - cublas - ) +add_executable(test_wrapper + ${test_files} + ${wrapper_files} + ${util_files}) -add_executable(wrapper_test wrapper_test.cpp ${wrapper_files} ${util_files}) -target_link_libraries(wrapper_test ${knowhere_libs} ${unittest_libs}) +target_link_libraries(test_wrapper + knowhere + ${unittest_libs}) -install(TARGETS wrapper_test DESTINATION unittest) \ No newline at end of file +install(TARGETS test_wrapper DESTINATION unittest) \ No newline at end of file diff --git a/cpp/unittest/wrapper/wrapper_test.cpp b/cpp/unittest/wrapper/test_wrapper.cpp similarity index 70% rename from cpp/unittest/wrapper/wrapper_test.cpp rename to cpp/unittest/wrapper/test_wrapper.cpp index 157f222cb4b1d990941d368cf3f96618533dfd0c..4dd2e33bd5427dae635a75a9d03625e9d237f1c6 100644 --- a/cpp/unittest/wrapper/wrapper_test.cpp +++ b/cpp/unittest/wrapper/test_wrapper.cpp @@ -16,18 +16,21 @@ // under the License. #include "utils/easylogging++.h" -#include "src/wrapper/VecIndex.h" +#include "wrapper/VecIndex.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" -#include "utils.h" +#include "wrapper/utils.h" #include - INITIALIZE_EASYLOGGINGPP -using namespace zilliz::milvus::engine; -//using namespace zilliz::knowhere; +namespace { + +namespace ms = milvus::engine; +namespace kw = knowhere; + +} // namespace using ::testing::TestWithParam; using ::testing::Values; @@ -39,49 +42,49 @@ constexpr int64_t DEVICE_ID = 0; class ParamGenerator { public: - static ParamGenerator& GetInstance(){ + static ParamGenerator &GetInstance() { static ParamGenerator instance; return instance; } - Config Gen(const IndexType& type) { + kw::Config Gen(const ms::IndexType &type) { switch (type) { - case IndexType::FAISS_IDMAP: { - auto tempconf = std::make_shared(); - tempconf->metric_type = zilliz::knowhere::METRICTYPE::L2; + case ms::IndexType::FAISS_IDMAP: { + auto tempconf = std::make_shared(); + tempconf->metric_type = knowhere::METRICTYPE::L2; return tempconf; } - case IndexType::FAISS_IVFFLAT_CPU: - case IndexType::FAISS_IVFFLAT_GPU: - case IndexType::FAISS_IVFFLAT_MIX: { - auto tempconf = std::make_shared(); + case ms::IndexType::FAISS_IVFFLAT_CPU: + case ms::IndexType::FAISS_IVFFLAT_GPU: + case ms::IndexType::FAISS_IVFFLAT_MIX: { + auto tempconf = std::make_shared(); tempconf->nlist = 100; tempconf->nprobe = 16; - tempconf->metric_type = zilliz::knowhere::METRICTYPE::L2; + tempconf->metric_type = knowhere::METRICTYPE::L2; return tempconf; } - case IndexType::FAISS_IVFSQ8_CPU: - case IndexType::FAISS_IVFSQ8_GPU: - case IndexType::FAISS_IVFSQ8_MIX: { - auto tempconf = std::make_shared(); + case ms::IndexType::FAISS_IVFSQ8_CPU: + case ms::IndexType::FAISS_IVFSQ8_GPU: + case ms::IndexType::FAISS_IVFSQ8_MIX: { + auto tempconf = std::make_shared(); tempconf->nlist = 100; tempconf->nprobe = 16; tempconf->nbits = 8; - tempconf->metric_type = zilliz::knowhere::METRICTYPE::L2; + tempconf->metric_type = knowhere::METRICTYPE::L2; return tempconf; } - case IndexType::FAISS_IVFPQ_CPU: - case IndexType::FAISS_IVFPQ_GPU: { - auto tempconf = std::make_shared(); + case ms::IndexType::FAISS_IVFPQ_CPU: + case ms::IndexType::FAISS_IVFPQ_GPU: { + auto tempconf = std::make_shared(); tempconf->nlist = 100; tempconf->nprobe = 16; tempconf->nbits = 8; tempconf->m = 8; - tempconf->metric_type = zilliz::knowhere::METRICTYPE::L2; + tempconf->metric_type = knowhere::METRICTYPE::L2; return tempconf; } - case IndexType::NSG_MIX: { - auto tempconf = std::make_shared(); + case ms::IndexType::NSG_MIX: { + auto tempconf = std::make_shared(); tempconf->nlist = 100; tempconf->nprobe = 16; tempconf->search_length = 8; @@ -89,7 +92,7 @@ class ParamGenerator { tempconf->search_length = 40; // TODO(linxj): be 20 when search tempconf->out_degree = 60; tempconf->candidate_pool_size = 200; - tempconf->metric_type = zilliz::knowhere::METRICTYPE::L2; + tempconf->metric_type = knowhere::METRICTYPE::L2; return tempconf; } } @@ -97,10 +100,13 @@ class ParamGenerator { }; class KnowhereWrapperTest - : public TestWithParam<::std::tuple> { + : public TestWithParam<::std::tuple> { protected: void SetUp() override { - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID,1024*1024*200, 1024*1024*300, 2); + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, + 1024 * 1024 * 200, + 1024 * 1024 * 300, + 2); std::string generator_type; std::tie(index_type, generator_type, dim, nb, nq, k) = GetParam(); @@ -115,11 +121,12 @@ class KnowhereWrapperTest conf->d = dim; conf->gpu_id = DEVICE_ID; } + void TearDown() override { - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().Free(); + knowhere::FaissGpuResourceMgr::GetInstance().Free(); } - void AssertResult(const std::vector &ids, const std::vector &dis) { + void AssertResult(const std::vector &ids, const std::vector &dis) { EXPECT_EQ(ids.size(), nq * k); EXPECT_EQ(dis.size(), nq * k); @@ -146,8 +153,8 @@ class KnowhereWrapperTest } protected: - IndexType index_type; - Config conf; + ms::IndexType index_type; + kw::Config conf; int dim = DIM; int nb = NB; @@ -155,27 +162,27 @@ class KnowhereWrapperTest int k = 10; std::vector xb; std::vector xq; - std::vector ids; + std::vector ids; - VecIndexPtr index_ = nullptr; + ms::VecIndexPtr index_ = nullptr; // Ground Truth - std::vector gt_ids; + std::vector gt_ids; std::vector gt_dis; }; INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, Values( //["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] - std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default", 64, 100000, 10, 10), - std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default", DIM, NB, 10, 10), - std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 100000, 10, 10), - std::make_tuple(IndexType::FAISS_IVFSQ8_CPU, "Default", DIM, NB, 10, 10), -// std::make_tuple(IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB, 10, 10), - std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default", DIM, NB, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFFLAT_CPU, "Default", 64, 100000, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFFLAT_GPU, "Default", DIM, NB, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 100000, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFSQ8_CPU, "Default", DIM, NB, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB, 10, 10), + std::make_tuple(ms::IndexType::FAISS_IVFSQ8_MIX, "Default", DIM, NB, 10, 10), // std::make_tuple(IndexType::NSG_MIX, "Default", 128, 250000, 10, 10), // std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 250000, 10, 10), - std::make_tuple(IndexType::FAISS_IDMAP, "Default", 64, 100000, 10, 10) + std::make_tuple(ms::IndexType::FAISS_IDMAP, "Default", 64, 100000, 10, 10) ) ); @@ -213,7 +220,7 @@ TEST_P(KnowhereWrapperTest, TO_GPU_TEST) { { std::string file_location = "/tmp/knowhere_gpu_file"; write_index(index_, file_location); - auto new_index = read_index(file_location); + auto new_index = ms::read_index(file_location); auto dev_idx = new_index->CopyToGpu(DEVICE_ID); for (int i = 0; i < 10; ++i) { @@ -254,7 +261,7 @@ TEST_P(KnowhereWrapperTest, SERIALIZE_TEST) { { std::string file_location = "/tmp/knowhere"; write_index(index_, file_location); - auto new_index = read_index(file_location); + auto new_index = ms::read_index(file_location); EXPECT_EQ(new_index->GetType(), ConvertToCpuIndexType(index_type)); EXPECT_EQ(new_index->Dimension(), index_->Dimension()); EXPECT_EQ(new_index->Count(), index_->Count()); diff --git a/cpp/unittest/wrapper/utils.cpp b/cpp/unittest/wrapper/utils.cpp index f404f000e9f753950ce62130619469c59d19dfd5..f2bb83b482b5d8a90d0a8e92135be496e6a83912 100644 --- a/cpp/unittest/wrapper/utils.cpp +++ b/cpp/unittest/wrapper/utils.cpp @@ -18,12 +18,12 @@ #include -#include "utils.h" +#include "wrapper/utils.h" - -void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, - float *xb, float *xq, long *ids, - const int &k, long *gt_ids, float *gt_dis) { +void +DataGenBase::GenData(const int &dim, const int &nb, const int &nq, + float *xb, float *xq, int64_t *ids, + const int &k, int64_t *gt_ids, float *gt_dis) { for (auto i = 0; i < nb; ++i) { for (auto j = 0; j < dim; ++j) { //p_data[i * d + j] = float(base + i); @@ -42,15 +42,16 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq, index.search(nq, xq, k, gt_dis, gt_ids); } -void DataGenBase::GenData(const int &dim, - const int &nb, - const int &nq, - std::vector &xb, - std::vector &xq, - std::vector &ids, - const int &k, - std::vector >_ids, - std::vector >_dis) { +void +DataGenBase::GenData(const int &dim, + const int &nb, + const int &nq, + std::vector &xb, + std::vector &xq, + std::vector &ids, + const int &k, + std::vector >_ids, + std::vector >_dis) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); diff --git a/cpp/unittest/wrapper/utils.h b/cpp/unittest/wrapper/utils.h index 7a73d8471316789257949945645481b510487442..ff4ce9c23a52c038a298a60b7b06cb671927821a 100644 --- a/cpp/unittest/wrapper/utils.h +++ b/cpp/unittest/wrapper/utils.h @@ -24,25 +24,23 @@ #include #include - class DataGenBase; using DataGenPtr = std::shared_ptr; - class DataGenBase { public: - virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids, - const int &k, long *gt_ids, float *gt_dis); + virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, int64_t *ids, + const int &k, int64_t *gt_ids, float *gt_dis); virtual void GenData(const int &dim, const int &nb, const int &nq, std::vector &xb, std::vector &xq, - std::vector &ids, + std::vector &ids, const int &k, - std::vector >_ids, + std::vector >_ids, std::vector >_dis); }; diff --git a/cpp/version.h.macro b/cpp/version.h.macro index 2463407a0dfe49603c60440cf0f71107dd9c81af..454d8a990a503394efba2f22bda8d620cf5ddcf1 100644 --- a/cpp/version.h.macro +++ b/cpp/version.h.macro @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #pragma once #define MILVUS_VERSION "@MILVUS_VERSION@"