提交 0c0c8ee1 编写于 作者: J JinHai-CN

Merge from main branch


Former-commit-id: e74d537ac6217dbf59e1ee2d7d16307fef760f82
container('milvus-build-env') { container('milvus-build-env') {
timeout(time: 40, unit: 'MINUTES') { timeout(time: 120, unit: 'MINUTES') {
gitlabCommitStatus(name: 'Build Engine') { gitlabCommitStatus(name: 'Build Engine') {
dir ("milvus_engine") { dir ("milvus_engine") {
try { try {
......
container('milvus-build-env') { container('milvus-build-env') {
timeout(time: 40, unit: 'MINUTES') { timeout(time: 120, unit: 'MINUTES') {
gitlabCommitStatus(name: 'Build Engine') { gitlabCommitStatus(name: 'Build Engine') {
dir ("milvus_engine") { dir ("milvus_engine") {
try { try {
......
...@@ -22,6 +22,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -22,6 +22,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-624 - Search vectors failed if time ranges long enough - MS-624 - Search vectors failed if time ranges long enough
- MS-652 - IVFSQH quantization double free - MS-652 - IVFSQH quantization double free
- MS-605 - Server going down during searching vectors - MS-605 - Server going down during searching vectors
- MS-654 - Describe index timeout when building index
## Improvement ## Improvement
- MS-552 - Add and change the easylogging library - MS-552 - Add and change the easylogging library
...@@ -43,6 +44,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -43,6 +44,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-614 - Preload table at startup - MS-614 - Preload table at startup
- MS-626 - Refactor DataObj to support cache any type data - MS-626 - Refactor DataObj to support cache any type data
- MS-648 - Improve unittest - MS-648 - Improve unittest
- MS-655 - Upgrade SPTAG
## New Feature ## New Feature
- MS-627 - Integrate new index: IVFSQHybrid - MS-627 - Integrate new index: IVFSQHybrid
......
...@@ -125,10 +125,6 @@ set(MILVUS_SOURCE_DIR ${PROJECT_SOURCE_DIR}) ...@@ -125,10 +125,6 @@ set(MILVUS_SOURCE_DIR ${PROJECT_SOURCE_DIR})
set(MILVUS_BINARY_DIR ${PROJECT_BINARY_DIR}) set(MILVUS_BINARY_DIR ${PROJECT_BINARY_DIR})
set(MILVUS_ENGINE_SRC ${PROJECT_SOURCE_DIR}/src) set(MILVUS_ENGINE_SRC ${PROJECT_SOURCE_DIR}/src)
if (CUSTOMIZATION)
add_definitions(-DCUSTOMIZATION)
endif (CUSTOMIZATION)
include(ExternalProject) include(ExternalProject)
include(DefineOptions) include(DefineOptions)
include(BuildUtils) include(BuildUtils)
......
...@@ -88,6 +88,11 @@ function(ExternalProject_Create_Cache project_name package_file install_path cac ...@@ -88,6 +88,11 @@ function(ExternalProject_Create_Cache project_name package_file install_path cac
file(REMOVE ${package_file}) file(REMOVE ${package_file})
endif() endif()
string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file})
if(NOT EXISTS ${package_dir})
file(MAKE_DIRECTORY ${package_dir})
endif()
message(STATUS "Will create cached package file: ${package_file}") message(STATUS "Will create cached package file: ${package_file}")
ExternalProject_Add_Step(${project_name} package ExternalProject_Add_Step(${project_name} package
......
...@@ -158,6 +158,10 @@ if(USE_JFROG_CACHE STREQUAL "ON") ...@@ -158,6 +158,10 @@ if(USE_JFROG_CACHE STREQUAL "ON")
endif() endif()
set(THIRDPARTY_PACKAGE_CACHE "${THIRDPARTY_DIR}/cache") set(THIRDPARTY_PACKAGE_CACHE "${THIRDPARTY_DIR}/cache")
if(NOT EXISTS ${THIRDPARTY_PACKAGE_CACHE})
message(STATUS "Will create cached directory: ${THIRDPARTY_PACKAGE_CACHE}")
file(MAKE_DIRECTORY ${THIRDPARTY_PACKAGE_CACHE})
endif()
endif() endif()
macro(resolve_dependency DEPENDENCY_NAME) macro(resolve_dependency DEPENDENCY_NAME)
...@@ -324,8 +328,8 @@ if(DEFINED ENV{MILVUS_SQLITE_ORM_URL}) ...@@ -324,8 +328,8 @@ if(DEFINED ENV{MILVUS_SQLITE_ORM_URL})
set(SQLITE_ORM_SOURCE_URL "$ENV{MILVUS_SQLITE_ORM_URL}") set(SQLITE_ORM_SOURCE_URL "$ENV{MILVUS_SQLITE_ORM_URL}")
else() else()
set(SQLITE_ORM_SOURCE_URL set(SQLITE_ORM_SOURCE_URL
"http://192.168.1.105:6060/Test/sqlite_orm/-/archive/master/sqlite_orm-master.zip") # "http://192.168.1.105:6060/Test/sqlite_orm/-/archive/master/sqlite_orm-master.zip")
# "https://github.com/fnc12/sqlite_orm/archive/${SQLITE_ORM_VERSION}.zip") "https://github.com/fnc12/sqlite_orm/archive/${SQLITE_ORM_VERSION}.zip")
endif() endif()
set(SQLITE_ORM_MD5 "ba9a405a8a1421c093aa8ce988ff8598") set(SQLITE_ORM_MD5 "ba9a405a8a1421c093aa8ce988ff8598")
...@@ -372,7 +376,7 @@ else() ...@@ -372,7 +376,7 @@ else()
set(GRPC_SOURCE_URL set(GRPC_SOURCE_URL
"https://github.com/youny626/grpc-milvus/archive/${GRPC_VERSION}.zip") "https://github.com/youny626/grpc-milvus/archive/${GRPC_VERSION}.zip")
endif() endif()
set(GRPC_MD5 "fdd2656424c0e0e046b21354513fc70f") set(GRPC_MD5 "0362ba219f59432c530070b5f5c3df73")
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
......
...@@ -39,27 +39,6 @@ mysql_exc "GRANT ALL PRIVILEGES ON ${MYSQL_DB_NAME}.* TO '${MYSQL_USER_NAME}'@'% ...@@ -39,27 +39,6 @@ mysql_exc "GRANT ALL PRIVILEGES ON ${MYSQL_DB_NAME}.* TO '${MYSQL_USER_NAME}'@'%
mysql_exc "FLUSH PRIVILEGES;" mysql_exc "FLUSH PRIVILEGES;"
mysql_exc "USE ${MYSQL_DB_NAME};" mysql_exc "USE ${MYSQL_DB_NAME};"
# Connection settings used by mysql_exc below.
# NOTE(review): the password is hard-coded in the script; consider reading it
# from the environment or a credentials file instead.
MYSQL_USER_NAME=root
MYSQL_PASSWORD=Fantast1c
MYSQL_HOST='192.168.1.194'
# NOTE(review): MYSQL_PORT is set here but mysql_exc does not pass it to the
# client (no -P option), so the client's default port is what gets used.
MYSQL_PORT='3306'
# Database name suffixed with the output of `date +%s%N`
# (seconds since the epoch followed by nanoseconds).
MYSQL_DB_NAME=milvus_`date +%s%N`
# mysql_exc SQL
#   Runs the statement given as the first argument through the `mysql`
#   command-line client, using the connection settings above.
#   If the client exits non-zero, a message is echoed; the failure is only
#   reported and is not propagated to the caller as a non-zero status.
#   `cmd` is assigned without `local`, so it is visible outside the function.
function mysql_exc()
{
cmd=$1
mysql -h${MYSQL_HOST} -u${MYSQL_USER_NAME} -p${MYSQL_PASSWORD} -e "${cmd}"
if [ $? -ne 0 ]; then
echo "mysql $cmd run failed"
fi
}
# Create the database and grant the configured user all privileges on it.
mysql_exc "CREATE DATABASE IF NOT EXISTS ${MYSQL_DB_NAME};"
mysql_exc "GRANT ALL PRIVILEGES ON ${MYSQL_DB_NAME}.* TO '${MYSQL_USER_NAME}'@'%';"
mysql_exc "FLUSH PRIVILEGES;"
# NOTE(review): each mysql_exc call starts a separate client process, so this
# USE presumably has no effect on later invocations — confirm it is needed.
mysql_exc "USE ${MYSQL_DB_NAME};"
# get baseline # get baseline
${LCOV_CMD} -c -i -d ${DIR_GCNO} -o "${FILE_INFO_BASE}" ${LCOV_CMD} -c -i -d ${DIR_GCNO} -o "${FILE_INFO_BASE}"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
......
...@@ -26,6 +26,11 @@ include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus) ...@@ -26,6 +26,11 @@ include_directories(${MILVUS_ENGINE_SRC}/grpc/gen-milvus)
#this statement must put here, since the CORE_INCLUDE_DIRS is defined in code/CMakeList.txt #this statement must put here, since the CORE_INCLUDE_DIRS is defined in code/CMakeList.txt
add_subdirectory(index) add_subdirectory(index)
if (CUSTOMIZATION)
add_definitions(-DCUSTOMIZATION)
endif (CUSTOMIZATION)
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE) set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
foreach (dir ${CORE_INCLUDE_DIRS}) foreach (dir ${CORE_INCLUDE_DIRS})
include_directories(${dir}) include_directories(${dir})
...@@ -182,8 +187,6 @@ target_link_libraries(milvus_server ...@@ -182,8 +187,6 @@ target_link_libraries(milvus_server
install(TARGETS milvus_server DESTINATION bin) install(TARGETS milvus_server DESTINATION bin)
install(FILES install(FILES
${CMAKE_SOURCE_DIR}/src/index/thirdparty/tbb/${CMAKE_SHARED_LIBRARY_PREFIX}tbb${CMAKE_SHARED_LIBRARY_SUFFIX}
${CMAKE_SOURCE_DIR}/src/index/thirdparty/tbb/${CMAKE_SHARED_LIBRARY_PREFIX}tbb${CMAKE_SHARED_LIBRARY_SUFFIX}.2
${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX}
${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX}.3 ${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX}.3
${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX}.3.2.4 ${CMAKE_BINARY_DIR}/mysqlpp_ep-prefix/src/mysqlpp_ep/lib/${CMAKE_SHARED_LIBRARY_PREFIX}mysqlpp${CMAKE_SHARED_LIBRARY_SUFFIX}.3.2.4
......
...@@ -251,11 +251,6 @@ DBImpl::InsertVectors(const std::string& table_id, uint64_t n, const float* vect ...@@ -251,11 +251,6 @@ DBImpl::InsertVectors(const std::string& table_id, uint64_t n, const float* vect
Status status; Status status;
milvus::server::CollectInsertMetrics metrics(n, status); milvus::server::CollectInsertMetrics metrics(n, status);
status = mem_mgr_->InsertVectors(table_id, n, vectors, vector_ids); status = mem_mgr_->InsertVectors(table_id, n, vectors, vector_ids);
// std::chrono::microseconds time_span =
// std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
// double average_time = double(time_span.count()) / n;
// ENGINE_LOG_DEBUG << "Insert vectors to cache finished";
return status; return status;
} }
...@@ -359,7 +354,7 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr ...@@ -359,7 +354,7 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr
return Status(DB_ERROR, "Milsvus server is shutdown!"); return Status(DB_ERROR, "Milsvus server is shutdown!");
} }
ENGINE_LOG_DEBUG << "Query by dates for table: " << table_id; ENGINE_LOG_DEBUG << "Query by dates for table: " << table_id << " date range count: " << dates.size();
// get all table files from table // get all table files from table
meta::DatePartionedTableFilesSchema files; meta::DatePartionedTableFilesSchema files;
...@@ -377,7 +372,7 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr ...@@ -377,7 +372,7 @@ DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t npr
} }
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results); status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status; return status;
} }
...@@ -389,7 +384,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ ...@@ -389,7 +384,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_
return Status(DB_ERROR, "Milsvus server is shutdown!"); return Status(DB_ERROR, "Milsvus server is shutdown!");
} }
ENGINE_LOG_DEBUG << "Query by file ids for table: " << table_id; ENGINE_LOG_DEBUG << "Query by file ids for table: " << table_id << " date range count: " << dates.size();
// get specified files // get specified files
std::vector<size_t> ids; std::vector<size_t> ids;
...@@ -418,7 +413,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ ...@@ -418,7 +413,7 @@ DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_
} }
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info before query
status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, dates, results); status = QueryAsync(table_id, file_id_array, k, nq, nprobe, vectors, results);
cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query cache::CpuCacheMgr::GetInstance()->PrintInfo(); // print cache info after query
return status; return status;
} }
...@@ -437,14 +432,13 @@ DBImpl::Size(uint64_t& result) { ...@@ -437,14 +432,13 @@ DBImpl::Size(uint64_t& result) {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status Status
DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { uint64_t nprobe, const float* vectors, QueryResults& results) {
server::CollectQueryMetrics metrics(nq); server::CollectQueryMetrics metrics(nq);
TimeRecorder rc(""); TimeRecorder rc("");
// step 1: get files to search // step 1: get files to search
ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size() ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size();
<< " date range count: " << dates.size();
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(0, k, nq, nprobe, vectors); scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(0, k, nq, nprobe, vectors);
for (auto& file : files) { for (auto& file : files) {
scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file); scheduler::TableFileSchemaPtr file_ptr = std::make_shared<meta::TableFileSchema>(file);
...@@ -458,32 +452,7 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi ...@@ -458,32 +452,7 @@ DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& fi
return job->GetStatus(); return job->GetStatus();
} }
// step 3: print time cost information // step 3: construct results
// double load_cost = context->LoadCost();
// double search_cost = context->SearchCost();
// double reduce_cost = context->ReduceCost();
// std::string load_info = TimeRecorder::GetTimeSpanStr(load_cost);
// std::string search_info = TimeRecorder::GetTimeSpanStr(search_cost);
// std::string reduce_info = TimeRecorder::GetTimeSpanStr(reduce_cost);
// if(search_cost > 0.0 || reduce_cost > 0.0) {
// double total_cost = load_cost + search_cost + reduce_cost;
// double load_percent = load_cost/total_cost;
// double search_percent = search_cost/total_cost;
// double reduce_percent = reduce_cost/total_cost;
//
// ENGINE_LOG_DEBUG << "Engine load index totally cost: " << load_info
// << " percent: " << load_percent*100 << "%";
// ENGINE_LOG_DEBUG << "Engine search index totally cost: " << search_info
// << " percent: " << search_percent*100 << "%";
// ENGINE_LOG_DEBUG << "Engine reduce topk totally cost: " << reduce_info
// << " percent: " << reduce_percent*100 << "%";
// } else {
// ENGINE_LOG_DEBUG << "Engine load cost: " << load_info
// << " search cost: " << search_info
// << " reduce cost: " << reduce_info;
// }
// step 4: construct results
results = job->GetResult(); results = job->GetResult();
rc.ElapseFromBegin("Engine query totally cost"); rc.ElapseFromBegin("Engine query totally cost");
...@@ -701,14 +670,13 @@ DBImpl::BackgroundMergeFiles(const std::string& table_id) { ...@@ -701,14 +670,13 @@ DBImpl::BackgroundMergeFiles(const std::string& table_id) {
return status; return status;
} }
bool has_merge = false;
for (auto& kv : raw_files) { for (auto& kv : raw_files) {
auto files = kv.second; auto files = kv.second;
if (files.size() < options_.merge_trigger_number_) { if (files.size() < options_.merge_trigger_number_) {
ENGINE_LOG_DEBUG << "Files number not greater equal than merge trigger number, skip merge action"; ENGINE_LOG_DEBUG << "Files number not greater equal than merge trigger number, skip merge action";
continue; continue;
} }
has_merge = true;
MergeFiles(table_id, kv.first, kv.second); MergeFiles(table_id, kv.first, kv.second);
if (shutting_down_.load(std::memory_order_acquire)) { if (shutting_down_.load(std::memory_order_acquire)) {
...@@ -776,127 +744,6 @@ DBImpl::StartBuildIndexTask(bool force) { ...@@ -776,127 +744,6 @@ DBImpl::StartBuildIndexTask(bool force) {
} }
} }
// Builds an index file for the given table file and records the result in meta.
//
// Steps: create an execution engine for `file`, load it, register a new
// NEW_INDEX table file in meta, build the index into that file's location,
// serialize it, then update meta so the new file becomes INDEX and the source
// file becomes BACKUP.
//
// Returns:
//   - Status(DB_ERROR, ...) when the engine cannot be created, or when building
//     or serializing throws.
//   - the failing status when Load() or CreateTableFile() fails.
//   - Status::OK() otherwise, including when the table no longer exists.
//
// NOTE(review): two paths can report a failure as OK — see the notes at the
// nullptr check in step 3 and at the else branch in step 6.
Status
DBImpl::BuildIndex(const meta::TableFileSchema& file) {
// Engine is created from the source file's dimension, location, engine type, metric type and nlist.
ExecutionEnginePtr to_index = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_,
(MetricType)file.metric_type_, file.nlist_);
if (to_index == nullptr) {
ENGINE_LOG_ERROR << "Invalid engine type";
return Status(DB_ERROR, "Invalid engine type");
}
try {
// step 1: load index
Status status = to_index->Load(options_.insert_cache_immediately_);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to load index file: " << status.ToString();
return status;
}
// step 2: create table file
// Only table_id_, date_ and file_type_ are set here; the remaining fields used
// below (location_, engine_type_, file_id_) are presumably filled in by
// CreateTableFile — confirm against the meta implementation.
meta::TableFileSchema table_file;
table_file.table_id_ = file.table_id_;
table_file.date_ = file.date_;
table_file.file_type_ =
meta::TableFileSchema::NEW_INDEX; // for multi-db-path, distribute index files evenly across the paths
status = meta_ptr_->CreateTableFile(table_file);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to create table file: " << status.ToString();
return status;
}
// step 3: build index
std::shared_ptr<ExecutionEngine> index;
try {
server::CollectBuildIndexMetrics metrics;
index = to_index->BuildIndex(table_file.location_, (EngineType)table_file.engine_type_);
if (index == nullptr) {
// Build produced nothing: mark the new file for deletion.
// NOTE(review): the value returned is the result of UpdateTableFile, not a
// build error, so a failed build can be reported as OK. The debug message is
// also logged regardless of whether the meta update succeeded.
table_file.file_type_ = meta::TableFileSchema::TO_DELETE;
status = meta_ptr_->UpdateTableFile(table_file);
ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_
<< " to to_delete";
return status;
}
} catch (std::exception& ex) {
// typical error: out of gpu memory
std::string msg = "BuildIndex encounter exception: " + std::string(ex.what());
ENGINE_LOG_ERROR << msg;
// Mark the new file for deletion; the result of this meta update is not checked.
table_file.file_type_ = meta::TableFileSchema::TO_DELETE;
status = meta_ptr_->UpdateTableFile(table_file);
ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete";
std::cout << "ERROR: failed to build index, index file is too large or gpu memory is not enough"
<< std::endl;
return Status(DB_ERROR, msg);
}
// step 4: if table has been deleted, dont save index file
// The status returned by HasTable is ignored; only the out-parameter is used.
bool has_table = false;
meta_ptr_->HasTable(file.table_id_, has_table);
if (!has_table) {
meta_ptr_->DeleteTableFiles(file.table_id_);
return Status::OK();
}
// step 5: save index file
try {
index->Serialize();
} catch (std::exception& ex) {
// typical error: out of disk space or permission denied
std::string msg = "Serialize index encounter exception: " + std::string(ex.what());
ENGINE_LOG_ERROR << msg;
// Mark the new file for deletion; the result of this meta update is not checked.
table_file.file_type_ = meta::TableFileSchema::TO_DELETE;
status = meta_ptr_->UpdateTableFile(table_file);
ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete";
std::cout << "ERROR: failed to persist index file: " << table_file.location_
<< ", possible out of disk space" << std::endl;
return Status(DB_ERROR, msg);
}
// step 6: update meta
// The new file becomes INDEX and a copy of the source file is marked BACKUP;
// both are written in a single UpdateTableFiles call.
table_file.file_type_ = meta::TableFileSchema::INDEX;
table_file.file_size_ = index->PhysicalSize();
table_file.row_count_ = index->Count();
auto origin_file = file;
origin_file.file_type_ = meta::TableFileSchema::BACKUP;
meta::TableFilesSchema update_files = {table_file, origin_file};
status = meta_ptr_->UpdateTableFiles(update_files);
if (status.ok()) {
ENGINE_LOG_DEBUG << "New index file " << table_file.file_id_ << " of size " << index->PhysicalSize()
<< " bytes"
<< " from file " << origin_file.file_id_;
if (options_.insert_cache_immediately_) {
index->Cache();
}
} else {
// failed to update meta, mark the new file as to_delete, don't delete old file
// NOTE(review): the statuses assigned below are overwritten and never returned;
// control falls through to the final `return Status::OK()`, so this failure is
// not visible to the caller.
origin_file.file_type_ = meta::TableFileSchema::TO_INDEX;
status = meta_ptr_->UpdateTableFile(origin_file);
ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << origin_file.file_id_ << " to to_index";
table_file.file_type_ = meta::TableFileSchema::TO_DELETE;
status = meta_ptr_->UpdateTableFile(table_file);
ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete";
}
} catch (std::exception& ex) {
// Catches any std::exception thrown in the steps above that was not handled by an inner try block.
std::string msg = "Build index encounter exception: " + std::string(ex.what());
ENGINE_LOG_ERROR << msg;
return Status(DB_ERROR, msg);
}
return Status::OK();
}
void void
DBImpl::BackgroundBuildIndex() { DBImpl::BackgroundBuildIndex() {
ENGINE_LOG_TRACE << "Background build index thread start"; ENGINE_LOG_TRACE << "Background build index thread start";
...@@ -921,17 +768,6 @@ DBImpl::BackgroundBuildIndex() { ...@@ -921,17 +768,6 @@ DBImpl::BackgroundBuildIndex() {
ENGINE_LOG_ERROR << "Building index failed: " << status.ToString(); ENGINE_LOG_ERROR << "Building index failed: " << status.ToString();
} }
} }
// for (auto &file : to_index_files) {
// status = BuildIndex(file);
// if (!status.ok()) {
// ENGINE_LOG_ERROR << "Building index for " << file.id_ << " failed: " << status.ToString();
// }
//
// if (shutting_down_.load(std::memory_order_acquire)) {
// ENGINE_LOG_DEBUG << "Server will shutdown, skip build index action";
// break;
// }
// }
ENGINE_LOG_TRACE << "Background build index thread exit"; ENGINE_LOG_TRACE << "Background build index thread exit";
} }
......
...@@ -107,7 +107,7 @@ class DBImpl : public DB { ...@@ -107,7 +107,7 @@ class DBImpl : public DB {
private: private:
Status Status
QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq, QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files, uint64_t k, uint64_t nq,
uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results); uint64_t nprobe, const float* vectors, QueryResults& results);
void void
BackgroundTimerTask(); BackgroundTimerTask();
...@@ -133,9 +133,6 @@ class DBImpl : public DB { ...@@ -133,9 +133,6 @@ class DBImpl : public DB {
void void
BackgroundBuildIndex(); BackgroundBuildIndex();
Status
BuildIndex(const meta::TableFileSchema&);
Status Status
MemSerialize(); MemSerialize();
......
...@@ -88,6 +88,11 @@ function(ExternalProject_Create_Cache project_name package_file install_path cac ...@@ -88,6 +88,11 @@ function(ExternalProject_Create_Cache project_name package_file install_path cac
file(REMOVE ${package_file}) file(REMOVE ${package_file})
endif() endif()
string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file})
if(NOT EXISTS ${package_dir})
file(MAKE_DIRECTORY ${package_dir})
endif()
message(STATUS "Will create cached package file: ${package_file}") message(STATUS "Will create cached package file: ${package_file}")
ExternalProject_Add_Step(${project_name} package ExternalProject_Add_Step(${project_name} package
......
...@@ -125,6 +125,10 @@ endif() ...@@ -125,6 +125,10 @@ endif()
if(USE_JFROG_CACHE STREQUAL "ON") if(USE_JFROG_CACHE STREQUAL "ON")
set(JFROG_ARTFACTORY_CACHE_URL "${JFROG_ARTFACTORY_URL}/milvus/thirdparty/cache/${CMAKE_OS_NAME}/${KNOWHERE_BUILD_ARCH}/${BUILD_TYPE}") set(JFROG_ARTFACTORY_CACHE_URL "${JFROG_ARTFACTORY_URL}/milvus/thirdparty/cache/${CMAKE_OS_NAME}/${KNOWHERE_BUILD_ARCH}/${BUILD_TYPE}")
set(THIRDPARTY_PACKAGE_CACHE "${THIRDPARTY_DIR}/cache") set(THIRDPARTY_PACKAGE_CACHE "${THIRDPARTY_DIR}/cache")
if(NOT EXISTS ${THIRDPARTY_PACKAGE_CACHE})
message(STATUS "Will create cached directory: ${THIRDPARTY_PACKAGE_CACHE}")
file(MAKE_DIRECTORY ${THIRDPARTY_PACKAGE_CACHE})
endif()
endif() endif()
macro(resolve_dependency DEPENDENCY_NAME) macro(resolve_dependency DEPENDENCY_NAME)
...@@ -240,6 +244,7 @@ if(CUSTOMIZATION) ...@@ -240,6 +244,7 @@ if(CUSTOMIZATION)
message(STATUS "Check the remote cache file ${FAISS_SOURCE_URL}. return code = ${return_code}") message(STATUS "Check the remote cache file ${FAISS_SOURCE_URL}. return code = ${return_code}")
if (NOT return_code EQUAL 0) if (NOT return_code EQUAL 0)
set(FAISS_SOURCE_URL "https://github.com/facebookresearch/faiss/archive/v1.5.3.tar.gz") set(FAISS_SOURCE_URL "https://github.com/facebookresearch/faiss/archive/v1.5.3.tar.gz")
set(CUSTOMIZATION FALSE PARENT_SCOPE)
endif() endif()
else() else()
set(FAISS_SOURCE_URL "https://github.com/facebookresearch/faiss/archive/v1.5.3.tar.gz") set(FAISS_SOURCE_URL "https://github.com/facebookresearch/faiss/archive/v1.5.3.tar.gz")
......
set(TBB_DIR ${CORE_SOURCE_DIR}/thirdparty/tbb)
set(TBB_LIBRARIES ${TBB_DIR}/libtbb.so)
include_directories(${TBB_DIR}/include)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64) link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64)
...@@ -60,7 +56,6 @@ set(index_srcs ...@@ -60,7 +56,6 @@ set(index_srcs
set(depend_libs set(depend_libs
SPTAGLibStatic SPTAGLibStatic
${TBB_LIBRARIES}
faiss faiss
openblas openblas
lapack lapack
...@@ -107,13 +102,6 @@ INSTALL(FILES ${OPENBLAS_REAL_STATIC_LIB} ...@@ -107,13 +102,6 @@ INSTALL(FILES ${OPENBLAS_REAL_STATIC_LIB}
DESTINATION lib DESTINATION lib
) )
INSTALL(FILES ${CORE_SOURCE_DIR}/thirdparty/tbb/libtbb.so.2
DESTINATION lib
)
INSTALL(FILES ${CORE_SOURCE_DIR}/thirdparty/tbb/libtbb.so
DESTINATION lib
)
set(CORE_INCLUDE_DIRS set(CORE_INCLUDE_DIRS
${CORE_SOURCE_DIR}/knowhere ${CORE_SOURCE_DIR}/knowhere
${CORE_SOURCE_DIR}/thirdparty ${CORE_SOURCE_DIR}/thirdparty
...@@ -122,7 +110,6 @@ set(CORE_INCLUDE_DIRS ...@@ -122,7 +110,6 @@ set(CORE_INCLUDE_DIRS
${FAISS_INCLUDE_DIR} ${FAISS_INCLUDE_DIR}
${OPENBLAS_INCLUDE_DIR} ${OPENBLAS_INCLUDE_DIR}
${LAPACK_INCLUDE_DIR} ${LAPACK_INCLUDE_DIR}
${CORE_SOURCE_DIR}/thirdparty/tbb/include
) )
set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE) set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
...@@ -132,7 +119,6 @@ set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE) ...@@ -132,7 +119,6 @@ set(CORE_INCLUDE_DIRS ${CORE_INCLUDE_DIRS} PARENT_SCOPE)
# ${ARROW_INCLUDE_DIR}/arrow # ${ARROW_INCLUDE_DIR}/arrow
# ${FAISS_PREFIX}/include/faiss # ${FAISS_PREFIX}/include/faiss
# ${OPENBLAS_INCLUDE_DIR}/ # ${OPENBLAS_INCLUDE_DIR}/
# ${CORE_SOURCE_DIR}/thirdparty/tbb/include/tbb
# DESTINATION # DESTINATION
# include) # include)
# #
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
namespace knowhere { namespace knowhere {
#ifdef CUSTOMIZATION #ifdef CUSTOMIZATION
IndexModelPtr IndexModelPtr
IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) { IVFSQHybrid::Train(const DatasetPtr& dataset, const Config& config) {
auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config); auto build_cfg = std::dynamic_pointer_cast<IVFSQCfg>(config);
......
...@@ -36,42 +36,47 @@ BinarySet ...@@ -36,42 +36,47 @@ BinarySet
CPUKDTRNG::Serialize() { CPUKDTRNG::Serialize() {
std::vector<void*> index_blobs; std::vector<void*> index_blobs;
std::vector<int64_t> index_len; std::vector<int64_t> index_len;
index_ptr_->SaveIndexToMemory(index_blobs, index_len);
// TODO(zirui): dev
// index_ptr_->SaveIndexToMemory(index_blobs, index_len);
BinarySet binary_set; BinarySet binary_set;
auto sample = std::make_shared<uint8_t>(); //
sample.reset(static_cast<uint8_t*>(index_blobs[0])); // auto sample = std::make_shared<uint8_t>();
auto tree = std::make_shared<uint8_t>(); // sample.reset(static_cast<uint8_t*>(index_blobs[0]));
tree.reset(static_cast<uint8_t*>(index_blobs[1])); // auto tree = std::make_shared<uint8_t>();
auto graph = std::make_shared<uint8_t>(); // tree.reset(static_cast<uint8_t*>(index_blobs[1]));
graph.reset(static_cast<uint8_t*>(index_blobs[2])); // auto graph = std::make_shared<uint8_t>();
auto metadata = std::make_shared<uint8_t>(); // graph.reset(static_cast<uint8_t*>(index_blobs[2]));
metadata.reset(static_cast<uint8_t*>(index_blobs[3])); // auto metadata = std::make_shared<uint8_t>();
// metadata.reset(static_cast<uint8_t*>(index_blobs[3]));
binary_set.Append("samples", sample, index_len[0]); //
binary_set.Append("tree", tree, index_len[1]); // binary_set.Append("samples", sample, index_len[0]);
binary_set.Append("graph", graph, index_len[2]); // binary_set.Append("tree", tree, index_len[1]);
binary_set.Append("metadata", metadata, index_len[3]); // binary_set.Append("graph", graph, index_len[2]);
// binary_set.Append("metadata", metadata, index_len[3]);
return binary_set; return binary_set;
} }
void void
CPUKDTRNG::Load(const BinarySet& binary_set) { CPUKDTRNG::Load(const BinarySet& binary_set) {
std::vector<void*> index_blobs; // TODO(zirui): dev
auto samples = binary_set.GetByName("samples"); // std::vector<void*> index_blobs;
index_blobs.push_back(samples->data.get()); //
// auto samples = binary_set.GetByName("samples");
auto tree = binary_set.GetByName("tree"); // index_blobs.push_back(samples->data.get());
index_blobs.push_back(tree->data.get()); //
// auto tree = binary_set.GetByName("tree");
auto graph = binary_set.GetByName("graph"); // index_blobs.push_back(tree->data.get());
index_blobs.push_back(graph->data.get()); //
// auto graph = binary_set.GetByName("graph");
auto metadata = binary_set.GetByName("metadata"); // index_blobs.push_back(graph->data.get());
index_blobs.push_back(metadata->data.get()); //
// auto metadata = binary_set.GetByName("metadata");
index_ptr_->LoadIndexFromMemory(index_blobs); // index_blobs.push_back(metadata->data.get());
//
// index_ptr_->LoadIndexFromMemory(index_blobs);
} }
// PreprocessorPtr // PreprocessorPtr
......
...@@ -89,5 +89,3 @@ dkms.conf ...@@ -89,5 +89,3 @@ dkms.conf
/Wrappers/inc/AnnClient.java /Wrappers/inc/AnnClient.java
/AnnService.users - Copy.props /AnnService.users - Copy.props
/.vs /.vs
Release/
Debug/
# Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. # Licensed under the MIT License.
file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h) file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/VectorSetReaders/*.h)
file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp) file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/VectorSetReaders/*.cpp)
include_directories(${PROJECT_SOURCE_DIR}/AnnService) include_directories(${PROJECT_SOURCE_DIR}/AnnService)
add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES}) add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES})
target_link_libraries (SPTAGLib ${TBB_LIBRARIES}) target_link_libraries (SPTAGLib)
add_library (SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES}) add_library (SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES})
set_target_properties(SPTAGLibStatic PROPERTIES OUTPUT_NAME SPTAGLib) set_target_properties(SPTAGLibStatic PROPERTIES OUTPUT_NAME SPTAGLib)
file(GLOB SERVER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Server/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB SERVER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Server/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB SERVER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Server/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) file(GLOB SERVER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Server/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES}) add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES})
target_link_libraries(server ${Boost_LIBRARIES} ${TBB_LIBRARIES}) target_link_libraries(server ${Boost_LIBRARIES})
file(GLOB CLIENT_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB CLIENT_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB CLIENT_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) file(GLOB CLIENT_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (client ${CLIENT_FILES} ${CLIENT_HDR_FILES}) add_executable (client ${CLIENT_FILES} ${CLIENT_HDR_FILES})
target_link_libraries(client ${Boost_LIBRARIES} ${TBB_LIBRARIES}) target_link_libraries(client ${Boost_LIBRARIES})
file(GLOB AGG_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Aggregator/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB AGG_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Aggregator/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB AGG_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Aggregator/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) file(GLOB AGG_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Aggregator/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (aggregator ${AGG_FILES} ${AGG_HDR_FILES}) add_executable (aggregator ${AGG_FILES} ${AGG_HDR_FILES})
target_link_libraries(aggregator ${Boost_LIBRARIES} ${TBB_LIBRARIES}) target_link_libraries(aggregator ${Boost_LIBRARIES})
file(GLOB BUILDER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/VectorSetReaders/*.h) file(GLOB BUILDER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/*.h)
file(GLOB BUILDER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/VectorSetReaders/*.cpp) file(GLOB BUILDER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/*.cpp)
add_executable (indexbuilder ${BUILDER_FILES} ${BUILDER_HDR_FILES}) add_executable (indexbuilder ${BUILDER_FILES} ${BUILDER_HDR_FILES})
target_link_libraries(indexbuilder ${Boost_LIBRARIES} ${TBB_LIBRARIES}) target_link_libraries(indexbuilder ${Boost_LIBRARIES})
file(GLOB SEARCHER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexSearcher/*.cpp) file(GLOB SEARCHER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexSearcher/*.cpp)
add_executable (indexsearcher ${SEARCHER_FILES} ${HDR_FILES}) add_executable (indexsearcher ${SEARCHER_FILES} ${HDR_FILES})
target_link_libraries(indexsearcher ${Boost_LIBRARIES} ${TBB_LIBRARIES}) target_link_libraries(indexsearcher ${Boost_LIBRARIES})
install(TARGETS SPTAGLib SPTAGLibStatic server client aggregator indexbuilder indexsearcher install(TARGETS SPTAGLib SPTAGLibStatic server client aggregator indexbuilder indexsearcher
RUNTIME DESTINATION bin RUNTIME DESTINATION bin
ARCHIVE DESTINATION lib ARCHIVE DESTINATION lib
LIBRARY DESTINATION lib) LIBRARY DESTINATION lib)
install(DIRECTORY inc DESTINATION include/sptag
FILES_MATCHING PATTERN "*.h")
\ No newline at end of file
...@@ -149,25 +149,29 @@ ...@@ -149,25 +149,29 @@
<ClInclude Include="inc\Core\DefinitionList.h" /> <ClInclude Include="inc\Core\DefinitionList.h" />
<ClInclude Include="inc\Core\MetadataSet.h" /> <ClInclude Include="inc\Core\MetadataSet.h" />
<ClInclude Include="inc\Core\SearchQuery.h" /> <ClInclude Include="inc\Core\SearchQuery.h" />
<ClInclude Include="inc\Core\SearchResult.h" />
<ClInclude Include="inc\Core\VectorIndex.h" /> <ClInclude Include="inc\Core\VectorIndex.h" />
<ClInclude Include="inc\Core\VectorSet.h" /> <ClInclude Include="inc\Core\VectorSet.h" />
<ClInclude Include="inc\Helper\ArgumentsParser.h" /> <ClInclude Include="inc\Helper\ArgumentsParser.h" />
<ClInclude Include="inc\Helper\Base64Encode.h" /> <ClInclude Include="inc\Helper\Base64Encode.h" />
<ClInclude Include="inc\Helper\BufferStream.h" />
<ClInclude Include="inc\Helper\CommonHelper.h" /> <ClInclude Include="inc\Helper\CommonHelper.h" />
<ClInclude Include="inc\Helper\Concurrent.h" /> <ClInclude Include="inc\Helper\Concurrent.h" />
<ClInclude Include="inc\Helper\ConcurrentSet.h" />
<ClInclude Include="inc\Helper\SimpleIniReader.h" /> <ClInclude Include="inc\Helper\SimpleIniReader.h" />
<ClInclude Include="inc\Helper\StringConvert.h" /> <ClInclude Include="inc\Helper\StringConvert.h" />
<ClInclude Include="inc\Core\Common\NeighborhoodGraph.h" /> <ClInclude Include="inc\Core\Common\NeighborhoodGraph.h" />
<ClInclude Include="inc\Core\Common\RelativeNeighborhoodGraph.h" /> <ClInclude Include="inc\Core\Common\RelativeNeighborhoodGraph.h" />
<ClInclude Include="inc\Core\Common\BKTree.h" /> <ClInclude Include="inc\Core\Common\BKTree.h" />
<ClInclude Include="inc\Core\Common\KDTree.h" /> <ClInclude Include="inc\Core\Common\KDTree.h" />
<ClInclude Include="inc\Helper\VectorSetReader.h" />
<ClInclude Include="inc\Helper\VectorSetReaders\DefaultReader.h" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="src\Core\BKT\BKTIndex.cpp" /> <ClCompile Include="src\Core\BKT\BKTIndex.cpp" />
<ClCompile Include="src\Core\Common\NeighborhoodGraph.cpp" /> <ClCompile Include="src\Core\Common\NeighborhoodGraph.cpp" />
<ClCompile Include="src\Core\KDT\KDTIndex.cpp" /> <ClCompile Include="src\Core\KDT\KDTIndex.cpp" />
<ClCompile Include="src\Core\Common\WorkSpacePool.cpp" /> <ClCompile Include="src\Core\Common\WorkSpacePool.cpp" />
<ClCompile Include="src\Core\CommonDataStructure.cpp" />
<ClCompile Include="src\Core\MetadataSet.cpp" /> <ClCompile Include="src\Core\MetadataSet.cpp" />
<ClCompile Include="src\Core\VectorIndex.cpp" /> <ClCompile Include="src\Core\VectorIndex.cpp" />
<ClCompile Include="src\Core\VectorSet.cpp" /> <ClCompile Include="src\Core\VectorSet.cpp" />
...@@ -176,18 +180,13 @@ ...@@ -176,18 +180,13 @@
<ClCompile Include="src\Helper\CommonHelper.cpp" /> <ClCompile Include="src\Helper\CommonHelper.cpp" />
<ClCompile Include="src\Helper\Concurrent.cpp" /> <ClCompile Include="src\Helper\Concurrent.cpp" />
<ClCompile Include="src\Helper\SimpleIniReader.cpp" /> <ClCompile Include="src\Helper\SimpleIniReader.cpp" />
<ClCompile Include="src\Helper\VectorSetReader.cpp" />
<ClCompile Include="src\Helper\VectorSetReaders\DefaultReader.cpp" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<None Include="packages.config" /> <None Include="packages.config" />
</ItemGroup> </ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets"> <ImportGroup Label="ExtensionTargets">
<Import Project="..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets" Condition="Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" />
</ImportGroup> </ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets'))" />
</Target>
</Project> </Project>
\ No newline at end of file
...@@ -38,6 +38,12 @@ ...@@ -38,6 +38,12 @@
<Filter Include="Source Files\Core\KDT"> <Filter Include="Source Files\Core\KDT">
<UniqueIdentifier>{8fb36afb-73ed-4c3d-8c9b-c3581d80c5d1}</UniqueIdentifier> <UniqueIdentifier>{8fb36afb-73ed-4c3d-8c9b-c3581d80c5d1}</UniqueIdentifier>
</Filter> </Filter>
<Filter Include="Header Files\Helper\VectorSetReaders">
<UniqueIdentifier>{f7bc0bc7-1af5-4870-b8ee-fabdbabdb4c4}</UniqueIdentifier>
</Filter>
<Filter Include="Source Files\Helper\VectorSetReaders">
<UniqueIdentifier>{5c1449e0-38b7-4c82-976e-cbdc488d3fb5}</UniqueIdentifier>
</Filter>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClInclude Include="inc\Core\Common.h"> <ClInclude Include="inc\Core\Common.h">
...@@ -52,6 +58,9 @@ ...@@ -52,6 +58,9 @@
<ClInclude Include="inc\Core\SearchQuery.h"> <ClInclude Include="inc\Core\SearchQuery.h">
<Filter>Header Files\Core</Filter> <Filter>Header Files\Core</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="inc\Core\SearchResult.h">
<Filter>Header Files\Core</Filter>
</ClInclude>
<ClInclude Include="inc\Core\VectorIndex.h"> <ClInclude Include="inc\Core\VectorIndex.h">
<Filter>Header Files\Core</Filter> <Filter>Header Files\Core</Filter>
</ClInclude> </ClInclude>
...@@ -130,11 +139,20 @@ ...@@ -130,11 +139,20 @@
<ClInclude Include="inc\Core\Common\BKTree.h"> <ClInclude Include="inc\Core\Common\BKTree.h">
<Filter>Header Files\Core\Common</Filter> <Filter>Header Files\Core\Common</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="inc\Helper\ConcurrentSet.h">
<Filter>Header Files\Helper</Filter>
</ClInclude>
<ClInclude Include="inc\Helper\BufferStream.h">
<Filter>Header Files\Helper</Filter>
</ClInclude>
<ClInclude Include="inc\Helper\VectorSetReaders\DefaultReader.h">
<Filter>Header Files\Helper\VectorSetReaders</Filter>
</ClInclude>
<ClInclude Include="inc\Helper\VectorSetReader.h">
<Filter>Header Files\Helper</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="src\Core\CommonDataStructure.cpp">
<Filter>Source Files\Core</Filter>
</ClCompile>
<ClCompile Include="src\Core\VectorIndex.cpp"> <ClCompile Include="src\Core\VectorIndex.cpp">
<Filter>Source Files\Core</Filter> <Filter>Source Files\Core</Filter>
</ClCompile> </ClCompile>
...@@ -171,6 +189,12 @@ ...@@ -171,6 +189,12 @@
<ClCompile Include="src\Core\Common\NeighborhoodGraph.cpp"> <ClCompile Include="src\Core\Common\NeighborhoodGraph.cpp">
<Filter>Source Files\Core\Common</Filter> <Filter>Source Files\Core\Common</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="src\Helper\VectorSetReaders\DefaultReader.cpp">
<Filter>Source Files\Helper\VectorSetReaders</Filter>
</ClCompile>
<ClCompile Include="src\Helper\VectorSetReader.cpp">
<Filter>Source Files\Helper</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<None Include="packages.config" /> <None Include="packages.config" />
......
...@@ -139,15 +139,11 @@ ...@@ -139,15 +139,11 @@
<ItemGroup> <ItemGroup>
<ClInclude Include="inc\IndexBuilder\Options.h" /> <ClInclude Include="inc\IndexBuilder\Options.h" />
<ClInclude Include="inc\IndexBuilder\ThreadPool.h" /> <ClInclude Include="inc\IndexBuilder\ThreadPool.h" />
<ClInclude Include="inc\IndexBuilder\VectorSetReader.h" />
<ClInclude Include="inc\IndexBuilder\VectorSetReaders\DefaultReader.h" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="src\IndexBuilder\main.cpp" /> <ClCompile Include="src\IndexBuilder\main.cpp" />
<ClCompile Include="src\IndexBuilder\Options.cpp" /> <ClCompile Include="src\IndexBuilder\Options.cpp" />
<ClCompile Include="src\IndexBuilder\ThreadPool.cpp" /> <ClCompile Include="src\IndexBuilder\ThreadPool.cpp" />
<ClCompile Include="src\IndexBuilder\VectorSetReader.cpp" />
<ClCompile Include="src\IndexBuilder\VectorSetReaders\DefaultReader.cpp" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<None Include="packages.config" /> <None Include="packages.config" />
...@@ -161,7 +157,6 @@ ...@@ -161,7 +157,6 @@
<Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" /> <Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" />
<Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" /> <Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" />
<Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" /> <Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" />
<Import Project="..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets" Condition="Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" />
</ImportGroup> </ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild"> <Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup> <PropertyGroup>
...@@ -174,6 +169,5 @@ ...@@ -174,6 +169,5 @@
<Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" />
<Error Condition="!Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets'))" />
</Target> </Target>
</Project> </Project>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup> <ItemGroup>
<Filter Include="Source Files"> <Filter Include="Source Files">
...@@ -9,12 +9,6 @@ ...@@ -9,12 +9,6 @@
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier> <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;hm;inl;inc;xsd</Extensions> <Extensions>h;hh;hpp;hxx;hm;inl;inc;xsd</Extensions>
</Filter> </Filter>
<Filter Include="Header Files\VectorSetReaders">
<UniqueIdentifier>{cf68b421-6a65-44f2-bf43-438b13940d7d}</UniqueIdentifier>
</Filter>
<Filter Include="Source Files\VectorSetReaders">
<UniqueIdentifier>{41ac91f9-6b6d-4341-8791-12f672d6ad5c}</UniqueIdentifier>
</Filter>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClInclude Include="inc\IndexBuilder\Options.h"> <ClInclude Include="inc\IndexBuilder\Options.h">
...@@ -23,27 +17,15 @@ ...@@ -23,27 +17,15 @@
<ClInclude Include="inc\IndexBuilder\ThreadPool.h"> <ClInclude Include="inc\IndexBuilder\ThreadPool.h">
<Filter>Header Files</Filter> <Filter>Header Files</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="inc\IndexBuilder\VectorSetReader.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="inc\IndexBuilder\VectorSetReaders\DefaultReader.h">
<Filter>Header Files\VectorSetReaders</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="src\IndexBuilder\Options.cpp"> <ClCompile Include="src\IndexBuilder\Options.cpp">
<Filter>Source Files</Filter> <Filter>Source Files</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="src\IndexBuilder\ThreadPool.cpp"> <ClCompile Include="src\IndexBuilder\main.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="src\IndexBuilder\VectorSetReader.cpp">
<Filter>Source Files</Filter> <Filter>Source Files</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="src\IndexBuilder\VectorSetReaders\DefaultReader.cpp"> <ClCompile Include="src\IndexBuilder\ThreadPool.cpp">
<Filter>Source Files\VectorSetReaders</Filter>
</ClCompile>
<ClCompile Include="src\IndexBuilder\main.cpp">
<Filter>Source Files</Filter> <Filter>Source Files</Filter>
</ClCompile> </ClCompile>
</ItemGroup> </ItemGroup>
......
...@@ -154,7 +154,6 @@ ...@@ -154,7 +154,6 @@
<Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" /> <Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" />
<Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" /> <Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" />
<Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" /> <Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" />
<Import Project="..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets" Condition="Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" />
</ImportGroup> </ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild"> <Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup> <PropertyGroup>
...@@ -167,6 +166,5 @@ ...@@ -167,6 +166,5 @@
<Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" />
<Error Condition="!Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets'))" />
</Target> </Target>
</Project> </Project>
\ No newline at end of file
...@@ -137,7 +137,6 @@ ...@@ -137,7 +137,6 @@
<Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" /> <Import Project="..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets" Condition="Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" />
<Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" /> <Import Project="..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets" Condition="Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" />
<Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" /> <Import Project="..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets" Condition="Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" />
<Import Project="..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets" Condition="Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" />
</ImportGroup> </ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild"> <Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup> <PropertyGroup>
...@@ -150,6 +149,5 @@ ...@@ -150,6 +149,5 @@
<Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_system-vc140.1.67.0.0\build\boost_system-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_thread-vc140.1.67.0.0\build\boost_thread-vc140.targets'))" />
<Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" /> <Error Condition="!Exists('..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\boost_wserialization-vc140.1.67.0.0\build\boost_wserialization-vc140.targets'))" />
<Error Condition="!Exists('..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\tbb_oss.9.107.0.0\build\native\tbb_oss.targets'))" />
</Target> </Target>
</Project> </Project>
\ No newline at end of file
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include "../Common/WorkSpacePool.h" #include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h" #include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/BKTree.h" #include "../Common/BKTree.h"
#include "inc/Helper/ConcurrentSet.h"
#include "inc/Helper/SimpleIniReader.h" #include "inc/Helper/SimpleIniReader.h"
#include "inc/Helper/StringConvert.h" #include "inc/Helper/StringConvert.h"
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG namespace SPTAG
{ {
...@@ -48,35 +48,38 @@ namespace SPTAG ...@@ -48,35 +48,38 @@ namespace SPTAG
std::string m_sBKTFilename; std::string m_sBKTFilename;
std::string m_sGraphFilename; std::string m_sGraphFilename;
std::string m_sDataPointsFilename; std::string m_sDataPointsFilename;
std::string m_sDeleteDataPointsFilename;
std::mutex m_dataLock; // protect data and graph std::mutex m_dataAddLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID; Helper::Concurrent::ConcurrentSet<SizeType> m_deletedID;
float m_fDeletePercentageForRefine;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool; std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads; int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod; DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length); float(*m_fComputeDistance)(const T* pX, const T* pY, DimensionType length);
int m_iMaxCheck; int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation; int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots; int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots; int m_iNumberOfOtherDynamicPivots;
public: public:
Index() Index()
{ {
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ #define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \ VarName = DefaultValue; \
#include "inc/Core/BKT/ParameterDefinitionList.h" #include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter #undef DefineBKTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod); m_pSamples.SetName("Vector");
} m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {} ~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); } inline SizeType GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); } inline DimensionType GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; } inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; } inline int GetNumThreads() const { return m_iNumberOfThreads; }
...@@ -85,25 +88,41 @@ namespace SPTAG ...@@ -85,25 +88,41 @@ namespace SPTAG
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); } inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; } inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; }
inline bool ContainSample(const SizeType idx) const { return !m_deletedID.contains(idx); }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension); inline bool NeedRefine() const { return m_deletedID.size() >= (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); }
std::shared_ptr<std::vector<std::uint64_t>> BufferSize() const
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen); {
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs); std::shared_ptr<std::vector<std::uint64_t>> buffersize(new std::vector<std::uint64_t>);
buffersize->push_back(m_pSamples.BufferSize());
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout); buffersize->push_back(m_pTrees.BufferSize());
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader); buffersize->push_back(m_pGraph.BufferSize());
buffersize->push_back(m_deletedID.bufferSize());
return std::move(buffersize);
}
ErrorCode SaveConfig(std::ostream& p_configout) const;
ErrorCode SaveIndexData(const std::string& p_folderPath);
ErrorCode SaveIndexData(const std::vector<std::ostream*>& p_indexStreams);
ErrorCode LoadConfig(Helper::IniReader& p_reader);
ErrorCode LoadIndexData(const std::string& p_folderPath);
ErrorCode LoadIndexDataFromMemory(const std::vector<ByteArray>& p_indexBlobs);
ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension);
ErrorCode SearchIndex(QueryResult &p_query) const; ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension); ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum); ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum);
ErrorCode DeleteIndex(const SizeType& p_id);
ErrorCode SetParameter(const char* p_param, const char* p_value); ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const; std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath); ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const; ErrorCode RefineIndex(const std::vector<std::ostream*>& p_indexStreams);
private:
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet<SizeType> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const; void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
}; };
} // namespace BKT } // namespace BKT
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath") DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath") DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath") DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineBKTParameter(m_sDeleteDataPointsFilename, std::string, std::string("deletes.bin"), "DeleteVectorFilePath")
DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber") DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber")
DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK") DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
...@@ -14,11 +15,11 @@ DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize") ...@@ -14,11 +15,11 @@ DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples") DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TpTreeNumber") DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize") DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit") DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit")
DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize") DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, DimensionType, 32L, "NeighborhoodSize")
DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale") DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale") DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations") DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
...@@ -28,6 +29,7 @@ DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckFor ...@@ -28,6 +29,7 @@ DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckFor
DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads") DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod") DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineBKTParameter(m_fDeletePercentageForRefine, float, 0.4F, "DeletePercentageForRefine")
DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck") DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation") DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots") DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
......
...@@ -56,9 +56,10 @@ inline bool fileexists(const char* path) { ...@@ -56,9 +56,10 @@ inline bool fileexists(const char* path) {
namespace SPTAG namespace SPTAG
{ {
typedef std::int32_t SizeType;
typedef std::int32_t DimensionType;
typedef std::uint32_t SizeType; const SizeType MaxSize = (std::numeric_limits<SizeType>::max)();
const float MinDist = (std::numeric_limits<float>::min)(); const float MinDist = (std::numeric_limits<float>::min)();
const float MaxDist = (std::numeric_limits<float>::max)(); const float MaxDist = (std::numeric_limits<float>::max)();
const float Epsilon = 0.000000001f; const float Epsilon = 0.000000001f;
...@@ -76,11 +77,6 @@ public: ...@@ -76,11 +77,6 @@ public:
#endif #endif
}; };
// Type of number index.
typedef std::int32_t IndexType;
static_assert(std::is_integral<IndexType>::value, "IndexType must be integral type.");
enum class ErrorCode : std::uint16_t enum class ErrorCode : std::uint16_t
{ {
#define DefineErrorCode(Name, Value) Name = Value, #define DefineErrorCode(Name, Value) Name = Value,
......
...@@ -36,9 +36,9 @@ namespace SPTAG ...@@ -36,9 +36,9 @@ namespace SPTAG
{ {
class Utils { class Utils {
public: public:
static int rand_int(int high = RAND_MAX, int low = 0) // Generates a random int value. static SizeType rand(SizeType high = MaxSize, SizeType low = 0) // Generates a random int value.
{ {
return low + (int)(float(high - low)*(std::rand() / (RAND_MAX + 1.0))); return low + (SizeType)(float(high - low)*(std::rand() / (RAND_MAX + 1.0)));
} }
static inline float atomic_float_add(volatile float* ptr, const float operand) static inline float atomic_float_add(volatile float* ptr, const float operand)
...@@ -61,11 +61,11 @@ namespace SPTAG ...@@ -61,11 +61,11 @@ namespace SPTAG
} }
} }
static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, int& NumDim) { static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, DimensionType& NumDim) {
char* current; char* current;
char* context = NULL; char* context = NULL;
int i = 0; DimensionType i = 0;
double sum = 0; double sum = 0;
arr.clear(); arr.clear();
current = strtok_s(cstr, sep, &context); current = strtok_s(cstr, sep, &context);
...@@ -90,23 +90,23 @@ namespace SPTAG ...@@ -90,23 +90,23 @@ namespace SPTAG
} }
template <typename T> template <typename T>
static void Normalize(T* arr, int col, int base) { static void Normalize(T* arr, DimensionType col, int base) {
double vecLen = 0; double vecLen = 0;
for (int j = 0; j < col; j++) { for (DimensionType j = 0; j < col; j++) {
double val = arr[j]; double val = arr[j];
vecLen += val * val; vecLen += val * val;
} }
vecLen = std::sqrt(vecLen); vecLen = std::sqrt(vecLen);
if (vecLen < 1e-6) { if (vecLen < 1e-6) {
T val = (T)(1.0 / std::sqrt((double)col) * base); T val = (T)(1.0 / std::sqrt((double)col) * base);
for (int j = 0; j < col; j++) arr[j] = val; for (DimensionType j = 0; j < col; j++) arr[j] = val;
} }
else { else {
for (int j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base); for (DimensionType j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base);
} }
} }
static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, int& D, int base, DistCalcMethod distCalcMethod) { static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, DimensionType& D, int base, DistCalcMethod distCalcMethod) {
size_t index; size_t index;
double vecLen; double vecLen;
if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) { if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) {
...@@ -121,10 +121,10 @@ namespace SPTAG ...@@ -121,10 +121,10 @@ namespace SPTAG
} }
template <typename T> template <typename T>
static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, int& NumQuery, int& NumDim, DistCalcMethod distCalcMethod, int base) { static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, SizeType& NumQuery, DimensionType& NumDim, DistCalcMethod distCalcMethod, int base) {
std::string currentLine; std::string currentLine;
std::vector<float> arr; std::vector<float> arr;
int i = 0; SizeType i = 0;
size_t index; size_t index;
while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) { while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) {
std::getline(inStream, currentLine); std::getline(inStream, currentLine);
...@@ -132,9 +132,9 @@ namespace SPTAG ...@@ -132,9 +132,9 @@ namespace SPTAG
continue; continue;
} }
qString.push_back(currentLine.substr(0, index)); qString.push_back(currentLine.substr(0, index));
if (Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0)); if ((SizeType)Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0));
for (int j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j]; for (DimensionType j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j];
i++; i++;
} }
NumQuery = i; NumQuery = i;
...@@ -149,12 +149,12 @@ namespace SPTAG ...@@ -149,12 +149,12 @@ namespace SPTAG
return 1; return 1;
} }
static inline void AddNeighbor(int idx, float dist, int *neighbors, float *dists, int size) static inline void AddNeighbor(SizeType idx, float dist, SizeType *neighbors, float *dists, DimensionType size)
{ {
size--; size--;
if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size])) if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size]))
{ {
int nb; DimensionType nb;
for (nb = 0; nb <= size && neighbors[nb] != idx; nb++); for (nb = 0; nb <= size && neighbors[nb] != idx; nb++);
if (nb > size) if (nb > size)
......
...@@ -13,158 +13,18 @@ namespace SPTAG ...@@ -13,158 +13,18 @@ namespace SPTAG
{ {
namespace COMMON namespace COMMON
{ {
const int bufsize = 1024 * 1024 * 1024; const int bufsize = 1 << 30;
class DataUtils { class DataUtils {
public: public:
template <typename T>
static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
std::ifstream inputStream(filename);
if (!inputStream.is_open()) {
std::cerr << "unable to open file " + filename << std::endl;
throw MyException("unable to open file " + filename);
exit(1);
}
std::ofstream outputStream, metaStream_out, metaStream_index;
outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
throw MyException("unable to open output files");
exit(1);
}
std::vector<float> arr;
std::vector<T> sample;
int base = 1;
if (distCalcMethod == DistCalcMethod::Cosine) {
base = Utils::GetBase<T>();
}
std::uint64_t writepos = 0;
int sampleSize = 0;
std::uint64_t totalread = 0;
std::streamoff startpos = id * blocksize;
#ifndef _MSC_VER
int enter_size = 1;
#else
int enter_size = 1;
#endif
std::string currentLine;
size_t index;
inputStream.seekg(startpos, std::ifstream::beg);
if (id != 0) {
std::getline(inputStream, currentLine);
totalread += currentLine.length() + enter_size;
}
std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
while (!inputStream.eof() && totalread <= blocksize) {
std::getline(inputStream, currentLine);
if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
totalread += currentLine.length() + enter_size;
continue;
}
sample.resize(D);
for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
outputStream.write((char *)(sample.data()), sizeof(T)*D);
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_out.write(currentLine.c_str(), index);
writepos += index;
sampleSize += 1;
totalread += currentLine.length() + enter_size;
}
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_index.write((char *)&sampleSize, sizeof(int));
inputStream.close();
outputStream.close();
metaStream_out.close();
metaStream_index.close();
numSamples.fetch_add(sampleSize);
std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
}
static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int D) {
std::ifstream inputStream;
std::ofstream outputStream;
char * buf = new char[bufsize];
std::uint64_t * offsets;
int partSamples;
int metaSamples = 0;
std::uint64_t lastoff = 0;
outputStream.open(outfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
outputStream.write((char *)&D, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
outputStream.open(outmetafile, std::ofstream::binary);
for (int i = 0; i < threadbase; i++) {
std::string file = outmetafile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
delete[] buf;
outputStream.open(outmetaindexfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outmetaindexfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
inputStream.read((char *)&partSamples, sizeof(int));
offsets = new std::uint64_t[partSamples + 1];
inputStream.seekg(0, inputStream.beg);
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
inputStream.close();
remove(file.c_str());
for (int j = 0; j < partSamples + 1; j++)
offsets[j] += lastoff;
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
lastoff = offsets[partSamples];
metaSamples += partSamples;
delete[] offsets;
}
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
}
static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1, static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) { const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
std::ifstream inputStream1, inputStream2; std::ifstream inputStream1, inputStream2;
std::ofstream outputStream; std::ofstream outputStream;
char * buf = new char[bufsize]; std::unique_ptr<char[]> bufferHolder(new char[bufsize]);
int R1, R2, C1, C2; char * buf = bufferHolder.get();
SizeType R1, R2;
DimensionType C1, C2;
#define MergeVector(inputStream, vectorFile, R, C) \ #define MergeVector(inputStream, vectorFile, R, C) \
inputStream.open(vectorFile, std::ifstream::binary); \ inputStream.open(vectorFile, std::ifstream::binary); \
...@@ -172,8 +32,8 @@ namespace SPTAG ...@@ -172,8 +32,8 @@ namespace SPTAG
std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \ std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
return false; \ return false; \
} \ } \
inputStream.read((char *)&(R), sizeof(int)); \ inputStream.read((char *)&(R), sizeof(SizeType)); \
inputStream.read((char *)&(C), sizeof(int)); \ inputStream.read((char *)&(C), sizeof(DimensionType)); \
MergeVector(inputStream1, p_vectorfile1, R1, C1) MergeVector(inputStream1, p_vectorfile1, R1, C1)
MergeVector(inputStream2, p_vectorfile2, R2, C2) MergeVector(inputStream2, p_vectorfile2, R2, C2)
...@@ -185,8 +45,8 @@ namespace SPTAG ...@@ -185,8 +45,8 @@ namespace SPTAG
} }
R1 += R2; R1 += R2;
outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary); outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int)); outputStream.write((char *)&R1, sizeof(SizeType));
outputStream.write((char *)&C1, sizeof(int)); outputStream.write((char *)&C1, sizeof(DimensionType));
while (!inputStream1.eof()) { while (!inputStream1.eof()) {
inputStream1.read(buf, bufsize); inputStream1.read(buf, bufsize);
outputStream.write(buf, inputStream1.gcount()); outputStream.write(buf, inputStream1.gcount());
...@@ -218,26 +78,22 @@ namespace SPTAG ...@@ -218,26 +78,22 @@ namespace SPTAG
outputStream.close(); outputStream.close();
delete[] buf; delete[] buf;
std::uint64_t * offsets = reinterpret_cast<std::uint64_t*>(buf);
std::uint64_t * offsets;
int partSamples;
std::uint64_t lastoff = 0; std::uint64_t lastoff = 0;
outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary); outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int)); outputStream.write((char *)&R1, sizeof(SizeType));
#define MergeMetaIndex(inputStream, metaIndexFile) \ #define MergeMetaIndex(inputStream, metaIndexFile) \
inputStream.open(metaIndexFile, std::ifstream::binary); \ inputStream.open(metaIndexFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \ if (!inputStream.is_open()) { \
std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \ std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
return false; \ return false; \
} \ } \
inputStream.read((char *)&partSamples, sizeof(int)); \ inputStream.read((char *)&R2, sizeof(SizeType)); \
offsets = new std::uint64_t[partSamples + 1]; \ inputStream.read((char *)offsets, sizeof(std::uint64_t)*(R2 + 1)); \
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
inputStream.close(); \ inputStream.close(); \
for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \ for (SizeType j = 0; j < R2 + 1; j++) offsets[j] += lastoff; \
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \ outputStream.write((char *)offsets, sizeof(std::uint64_t)*R2); \
lastoff = offsets[partSamples]; \ lastoff = offsets[R2]; \
delete[] offsets; \
MergeMetaIndex(inputStream1, p_metaindexfile1) MergeMetaIndex(inputStream1, p_metaindexfile1)
MergeMetaIndex(inputStream2, p_metaindexfile2) MergeMetaIndex(inputStream2, p_metaindexfile2)
...@@ -253,36 +109,6 @@ namespace SPTAG ...@@ -253,36 +109,6 @@ namespace SPTAG
std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl; std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
return true; return true;
} }
template <typename T>
static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
int threadnum, DistCalcMethod distCalcMethod) {
omp_set_num_threads(threadnum);
std::atomic_int numSamples = { 0 };
int D = -1;
int threadbase = 0;
std::vector<std::string> inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
for (std::string inputFileName : inputFileNames)
{
#ifndef _MSC_VER
struct stat stat_buf;
stat(inputFileName.c_str(), &stat_buf);
#else
struct _stat64 stat_buf;
int res = _stat64(inputFileName.c_str(), &stat_buf);
#endif
std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
#pragma omp parallel for
for (int i = 0; i < threadnum; i++) {
ProcessTSVData<T>(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
}
threadbase += threadnum;
}
MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
}
}; };
} }
} }
......
...@@ -28,23 +28,31 @@ namespace SPTAG ...@@ -28,23 +28,31 @@ namespace SPTAG
class Dataset class Dataset
{ {
private: private:
int rows; std::string name = "Data";
int cols; SizeType rows = 0;
DimensionType cols = 1;
bool ownData = false; bool ownData = false;
T* data = nullptr; T* data = nullptr;
std::vector<T> dataIncremental; SizeType incRows = 0;
std::vector<T*> incBlocks;
static const SizeType rowsInBlock = 1024 * 1024;
public: public:
Dataset(): rows(0), cols(1) {} Dataset()
Dataset(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true) {
incBlocks.reserve(MaxSize / rowsInBlock + 1);
}
Dataset(SizeType rows_, DimensionType cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{ {
Initialize(rows_, cols_, data_, transferOnwership_); Initialize(rows_, cols_, data_, transferOnwership_);
incBlocks.reserve(MaxSize / rowsInBlock + 1);
} }
~Dataset() ~Dataset()
{ {
if (ownData) aligned_free(data); if (ownData) aligned_free(data);
for (T* ptr : incBlocks) aligned_free(ptr);
incBlocks.clear();
} }
void Initialize(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true) void Initialize(SizeType rows_, DimensionType cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{ {
rows = rows_; rows = rows_;
cols = cols_; cols = cols_;
...@@ -52,161 +60,166 @@ namespace SPTAG ...@@ -52,161 +60,166 @@ namespace SPTAG
if (data_ == nullptr || !transferOnwership_) if (data_ == nullptr || !transferOnwership_)
{ {
ownData = true; ownData = true;
data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN); data = (T*)aligned_malloc(((size_t)rows) * cols * sizeof(T), ALIGN);
if (data_ != nullptr) memcpy(data, data_, rows * cols * sizeof(T)); if (data_ != nullptr) memcpy(data, data_, ((size_t)rows) * cols * sizeof(T));
else std::memset(data, -1, rows * cols * sizeof(T)); else std::memset(data, -1, ((size_t)rows) * cols * sizeof(T));
} }
} }
void SetR(int R_) void SetName(const std::string name_) { name = name_; }
void SetR(SizeType R_)
{ {
if (R_ >= rows) if (R_ >= rows)
dataIncremental.resize((R_ - rows) * cols); incRows = R_ - rows;
else else
{ {
rows = R_; rows = R_;
dataIncremental.clear(); incRows = 0;
} }
} }
inline int R() const { return (int)(rows + dataIncremental.size() / cols); } inline SizeType R() const { return rows + incRows; }
inline int C() const { return cols; } inline DimensionType C() const { return cols; }
T* operator[](int index) inline std::uint64_t BufferSize() const { return sizeof(SizeType) + sizeof(DimensionType) + sizeof(T) * R() * C(); }
inline const T* At(SizeType index) const
{ {
if (index >= rows) { if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols; SizeType incIndex = index - rows;
return incBlocks[incIndex / rowsInBlock] + ((size_t)(incIndex % rowsInBlock)) * cols;
} }
return data + (size_t)index*cols; return data + ((size_t)index) * cols;
} }
const T* operator[](int index) const T* operator[](SizeType index)
{ {
if (index >= rows) { return (T*)At(index);
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
} }
void AddBatch(const T* pData, int num) const T* operator[](SizeType index) const
{ {
dataIncremental.insert(dataIncremental.end(), pData, pData + num*cols); return At(index);
} }
void AddBatch(int num) ErrorCode AddBatch(const T* pData, SizeType num)
{ {
dataIncremental.insert(dataIncremental.end(), (size_t)num*cols, T(-1)); if (R() > MaxSize - num) return ErrorCode::MemoryOverFlow;
SizeType written = 0;
while (written < num) {
SizeType curBlockIdx = (incRows + written) / rowsInBlock;
if (curBlockIdx >= (SizeType)incBlocks.size()) {
T* newBlock = (T*)aligned_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN);
if (newBlock == nullptr) return ErrorCode::MemoryOverFlow;
incBlocks.push_back(newBlock);
}
SizeType curBlockPos = (incRows + written) % rowsInBlock;
SizeType toWrite = min(rowsInBlock - curBlockPos, num - written);
std::memcpy(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, pData + ((size_t)written) * cols, ((size_t)toWrite) * cols * sizeof(T));
written += toWrite;
}
incRows += written;
return ErrorCode::Success;
} }
bool Save(std::string sDataPointsFileName) ErrorCode AddBatch(SizeType num)
{ {
std::cout << "Save Data To " << sDataPointsFileName << std::endl; if (R() > MaxSize - num) return ErrorCode::MemoryOverFlow;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false; SizeType written = 0;
while (written < num) {
int CR = R(); SizeType curBlockIdx = (incRows + written) / rowsInBlock;
fwrite(&CR, sizeof(int), 1, fp); if (curBlockIdx >= (SizeType)incBlocks.size()) {
fwrite(&cols, sizeof(int), 1, fp); T* newBlock = (T*)aligned_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN);
if (newBlock == nullptr) return ErrorCode::MemoryOverFlow;
T* ptr = data; incBlocks.push_back(newBlock);
int toWrite = rows; }
while (toWrite > 0) SizeType curBlockPos = (incRows + written) % rowsInBlock;
{ SizeType toWrite = min(rowsInBlock - curBlockPos, num - written);
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp); std::memset(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, -1, ((size_t)toWrite) * cols * sizeof(T));
ptr += write * cols; written += toWrite;
toWrite -= (int)write;
}
ptr = dataIncremental.data();
toWrite = CR - rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
} }
fclose(fp); incRows += written;
return ErrorCode::Success;
}
std::cout << "Save Data (" << CR << ", " << cols << ") Finish!" << std::endl; bool Save(std::ostream& p_outstream) const
{
SizeType CR = R();
p_outstream.write((char*)&CR, sizeof(SizeType));
p_outstream.write((char*)&cols, sizeof(DimensionType));
p_outstream.write((char*)data, sizeof(T) * cols * rows);
SizeType blocks = incRows / rowsInBlock;
for (int i = 0; i < blocks; i++)
p_outstream.write((char*)incBlocks[i], sizeof(T) * cols * rowsInBlock);
SizeType remain = incRows % rowsInBlock;
if (remain > 0) p_outstream.write((char*)incBlocks[blocks], sizeof(T) * cols * remain);
std::cout << "Save " << name << " (" << CR << ", " << cols << ") Finish!" << std::endl;
return true; return true;
} }
bool Save(void **pDataPointsMemFile, int64_t &len) bool Save(std::string sDataPointsFileName) const
{ {
size_t size = sizeof(int) + sizeof(int) + sizeof(T) * R() *cols; std::cout << "Save " << name << " To " << sDataPointsFileName << std::endl;
char *mem = (char*)malloc(size); std::ofstream output(sDataPointsFileName, std::ios::binary);
if (mem == NULL) return false; if (!output.is_open()) return false;
Save(output);
int CR = R(); output.close();
auto header = (int*)mem;
header[0] = CR;
header[1] = cols;
auto body = &mem[8];
memcpy(body, data, sizeof(T) * cols * rows);
body += sizeof(T) * cols * rows;
memcpy(body, dataIncremental.data(), sizeof(T) * cols * (CR - rows));
body += sizeof(T) * cols * (CR - rows);
*pDataPointsMemFile = mem;
len = size;
return true; return true;
} }
bool Load(std::string sDataPointsFileName) bool Load(std::string sDataPointsFileName)
{ {
std::cout << "Load Data From " << sDataPointsFileName << std::endl; std::cout << "Load " << name << " From " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "rb"); std::ifstream input(sDataPointsFileName, std::ios::binary);
if (fp == NULL) return false; if (!input.is_open()) return false;
int R, C; input.read((char*)&rows, sizeof(SizeType));
fread(&R, sizeof(int), 1, fp); input.read((char*)&cols, sizeof(DimensionType));
fread(&C, sizeof(int), 1, fp);
Initialize(R, C); Initialize(rows, cols);
T* ptr = data; input.read((char*)data, sizeof(T) * cols * rows);
while (R > 0) { input.close();
size_t read = fread(ptr, sizeof(T) * C, R, fp); std::cout << "Load " << name << " (" << rows << ", " << cols << ") Finish!" << std::endl;
ptr += read * C;
R -= (int)read;
}
fclose(fp);
std::cout << "Load Data (" << rows << ", " << cols << ") Finish!" << std::endl;
return true; return true;
} }
// Functions for loading models from memory mapped files // Functions for loading models from memory mapped files
bool Load(char* pDataPointsMemFile) bool Load(char* pDataPointsMemFile)
{ {
int R, C; SizeType R;
R = *((int*)pDataPointsMemFile); DimensionType C;
pDataPointsMemFile += sizeof(int); R = *((SizeType*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(SizeType);
C = *((int*)pDataPointsMemFile); C = *((DimensionType*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int); pDataPointsMemFile += sizeof(DimensionType);
Initialize(R, C, (T*)pDataPointsMemFile); Initialize(R, C, (T*)pDataPointsMemFile);
std::cout << "Load " << name << " (" << R << ", " << C << ") Finish!" << std::endl;
return true; return true;
} }
bool Refine(const std::vector<int>& indices, std::string sDataPointsFileName) bool Refine(const std::vector<SizeType>& indices, std::ostream& output)
{ {
std::cout << "Save Refine Data To " << sDataPointsFileName << std::endl; SizeType R = (SizeType)(indices.size());
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb"); output.write((char*)&R, sizeof(SizeType));
if (fp == NULL) return false; output.write((char*)&cols, sizeof(DimensionType));
int R = (int)(indices.size());
fwrite(&R, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
// write point one by one in case for cache miss for (SizeType i = 0; i < R; i++) {
for (int i = 0; i < R; i++) { output.write((char*)At(indices[i]), sizeof(T) * cols);
if (indices[i] < rows)
fwrite(data + (size_t)indices[i] * cols, sizeof(T) * cols, 1, fp);
else
fwrite(dataIncremental.data() + (size_t)(indices[i] - rows) * cols, sizeof(T) * cols, 1, fp);
} }
fclose(fp); std::cout << "Save Refine " << name << " (" << R << ", " << cols << ") Finish!" << std::endl;
return true;
}
std::cout << "Save Refine Data (" << R << ", " << cols << ") Finish!" << std::endl; bool Refine(const std::vector<SizeType>& indices, std::string sDataPointsFileName)
{
std::cout << "Save Refine " << name << " To " << sDataPointsFileName << std::endl;
std::ofstream output(sDataPointsFileName, std::ios::binary);
if (!output.is_open()) return false;
Refine(indices, output);
output.close();
return true; return true;
} }
}; };
......
...@@ -199,7 +199,7 @@ namespace SPTAG ...@@ -199,7 +199,7 @@ namespace SPTAG
#endif #endif
/* /*
template<typename T> template<typename T>
static float ComputeL2Distance(const T *pX, const T *pY, int length) static float ComputeL2Distance(const T *pX, const T *pY, DimensionType length)
{ {
float diff = 0; float diff = 0;
const T* pEnd1 = pX + length; const T* pEnd1 = pX + length;
...@@ -217,7 +217,7 @@ namespace SPTAG ...@@ -217,7 +217,7 @@ namespace SPTAG
result = acc(result, exec(c1, c2)); \ result = acc(result, exec(c1, c2)); \
} \ } \
static float ComputeL2Distance(const std::int8_t *pX, const std::int8_t *pY, int length) static float ComputeL2Distance(const std::int8_t *pX, const std::int8_t *pY, DimensionType length)
{ {
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
...@@ -258,7 +258,7 @@ namespace SPTAG ...@@ -258,7 +258,7 @@ namespace SPTAG
return diff; return diff;
} }
static float ComputeL2Distance(const std::uint8_t *pX, const std::uint8_t *pY, int length) static float ComputeL2Distance(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length)
{ {
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
...@@ -299,7 +299,7 @@ namespace SPTAG ...@@ -299,7 +299,7 @@ namespace SPTAG
return diff; return diff;
} }
static float ComputeL2Distance(const std::int16_t *pX, const std::int16_t *pY, int length) static float ComputeL2Distance(const std::int16_t *pX, const std::int16_t *pY, DimensionType length)
{ {
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
...@@ -341,7 +341,7 @@ namespace SPTAG ...@@ -341,7 +341,7 @@ namespace SPTAG
return diff; return diff;
} }
static float ComputeL2Distance(const float *pX, const float *pY, int length) static float ComputeL2Distance(const float *pX, const float *pY, DimensionType length)
{ {
const float* pEnd16 = pX + ((length >> 4) << 4); const float* pEnd16 = pX + ((length >> 4) << 4);
const float* pEnd4 = pX + ((length >> 2) << 2); const float* pEnd4 = pX + ((length >> 2) << 2);
...@@ -389,14 +389,14 @@ namespace SPTAG ...@@ -389,14 +389,14 @@ namespace SPTAG
} }
/* /*
template<typename T> template<typename T>
static float ComputeCosineDistance(const T *pX, const T *pY, int length) { static float ComputeCosineDistance(const T *pX, const T *pY, DimensionType length) {
float diff = 0; float diff = 0;
const T* pEnd1 = pX + length; const T* pEnd1 = pX + length;
while (pX < pEnd1) diff += (*pX++) * (*pY++); while (pX < pEnd1) diff += (*pX++) * (*pY++);
return 1 - diff; return 1 - diff;
} }
*/ */
static float ComputeCosineDistance(const std::int8_t *pX, const std::int8_t *pY, int length) { static float ComputeCosineDistance(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) {
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); const std::int8_t* pEnd4 = pX + ((length >> 2) << 2);
...@@ -436,7 +436,7 @@ namespace SPTAG ...@@ -436,7 +436,7 @@ namespace SPTAG
return 16129 - diff; return 16129 - diff;
} }
static float ComputeCosineDistance(const std::uint8_t *pX, const std::uint8_t *pY, int length) { static float ComputeCosineDistance(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) {
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2);
...@@ -476,7 +476,7 @@ namespace SPTAG ...@@ -476,7 +476,7 @@ namespace SPTAG
return 65025 - diff; return 65025 - diff;
} }
static float ComputeCosineDistance(const std::int16_t *pX, const std::int16_t *pY, int length) { static float ComputeCosineDistance(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) {
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); const std::int16_t* pEnd4 = pX + ((length >> 2) << 2);
...@@ -517,7 +517,7 @@ namespace SPTAG ...@@ -517,7 +517,7 @@ namespace SPTAG
return 1073676289 - diff; return 1073676289 - diff;
} }
static float ComputeCosineDistance(const float *pX, const float *pY, int length) { static float ComputeCosineDistance(const float *pX, const float *pY, DimensionType length) {
const float* pEnd16 = pX + ((length >> 4) << 4); const float* pEnd16 = pX + ((length >> 4) << 4);
const float* pEnd4 = pX + ((length >> 2) << 2); const float* pEnd4 = pX + ((length >> 2) << 2);
const float* pEnd1 = pX + length; const float* pEnd1 = pX + length;
...@@ -564,7 +564,7 @@ namespace SPTAG ...@@ -564,7 +564,7 @@ namespace SPTAG
} }
template<typename T> template<typename T>
static inline float ComputeDistance(const T *p1, const T *p2, int length, SPTAG::DistCalcMethod distCalcMethod) static inline float ComputeDistance(const T *p1, const T *p2, DimensionType length, SPTAG::DistCalcMethod distCalcMethod)
{ {
if (distCalcMethod == SPTAG::DistCalcMethod::L2) if (distCalcMethod == SPTAG::DistCalcMethod::L2)
return ComputeL2Distance(p1, p2, length); return ComputeL2Distance(p1, p2, length);
...@@ -588,7 +588,7 @@ namespace SPTAG ...@@ -588,7 +588,7 @@ namespace SPTAG
template<typename T> template<typename T>
float (*DistanceCalcSelector(SPTAG::DistCalcMethod p_method)) (const T*, const T*, int) float (*DistanceCalcSelector(SPTAG::DistCalcMethod p_method)) (const T*, const T*, DimensionType)
{ {
switch (p_method) switch (p_method)
{ {
......
...@@ -16,30 +16,30 @@ namespace SPTAG ...@@ -16,30 +16,30 @@ namespace SPTAG
public: public:
FineGrainedLock() {} FineGrainedLock() {}
~FineGrainedLock() { ~FineGrainedLock() {
for (int i = 0; i < locks.size(); i++) for (size_t i = 0; i < locks.size(); i++)
locks[i].reset(); locks[i].reset();
locks.clear(); locks.clear();
} }
void resize(int n) { void resize(SizeType n) {
int current = (int)locks.size(); SizeType current = (SizeType)locks.size();
if (current <= n) { if (current <= n) {
locks.resize(n); locks.resize(n);
for (int i = current; i < n; i++) for (SizeType i = current; i < n; i++)
locks[i].reset(new std::mutex); locks[i].reset(new std::mutex);
} }
else { else {
for (int i = n; i < current; i++) for (SizeType i = n; i < current; i++)
locks[i].reset(); locks[i].reset();
locks.resize(n); locks.resize(n);
} }
} }
std::mutex& operator[](int idx) { std::mutex& operator[](SizeType idx) {
return *locks[idx]; return *locks[idx];
} }
const std::mutex& operator[](int idx) const { const std::mutex& operator[](SizeType idx) const {
return *locks[idx]; return *locks[idx];
} }
private: private:
......
...@@ -23,9 +23,9 @@ namespace SPTAG ...@@ -23,9 +23,9 @@ namespace SPTAG
// node type for storing KDT // node type for storing KDT
struct KDTNode struct KDTNode
{ {
int left; SizeType left;
int right; SizeType right;
short split_dim; DimensionType split_dim;
float split_value; float split_value;
}; };
...@@ -39,18 +39,18 @@ namespace SPTAG ...@@ -39,18 +39,18 @@ namespace SPTAG
m_iSamples(other.m_iSamples) {} m_iSamples(other.m_iSamples) {}
~KDTree() {} ~KDTree() {}
inline const KDTNode& operator[](int index) const { return m_pTreeRoots[index]; } inline const KDTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; }
inline KDTNode& operator[](int index) { return m_pTreeRoots[index]; } inline KDTNode& operator[](SizeType index) { return m_pTreeRoots[index]; }
inline int size() const { return (int)m_pTreeRoots.size(); } inline SizeType size() const { return (SizeType)m_pTreeRoots.size(); }
template <typename T> template <typename T>
void BuildTrees(VectorIndex* p_index, std::vector<int>* indices = nullptr) void BuildTrees(VectorIndex* p_index, std::vector<SizeType>* indices = nullptr)
{ {
std::vector<int> localindices; std::vector<SizeType> localindices;
if (indices == nullptr) { if (indices == nullptr) {
localindices.resize(p_index->GetNumSamples()); localindices.resize(p_index->GetNumSamples());
for (int i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i; for (SizeType i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i;
} }
else { else {
localindices.assign(indices->begin(), indices->end()); localindices.assign(indices->begin(), indices->end());
...@@ -63,58 +63,41 @@ namespace SPTAG ...@@ -63,58 +63,41 @@ namespace SPTAG
{ {
Sleep(i * 100); std::srand(clock()); Sleep(i * 100); std::srand(clock());
std::vector<int> pindices(localindices.begin(), localindices.end()); std::vector<SizeType> pindices(localindices.begin(), localindices.end());
std::random_shuffle(pindices.begin(), pindices.end()); std::random_shuffle(pindices.begin(), pindices.end());
m_pTreeStart[i] = i * (int)pindices.size(); m_pTreeStart[i] = i * (SizeType)pindices.size();
std::cout << "Start to build KDTree " << i + 1 << std::endl; std::cout << "Start to build KDTree " << i + 1 << std::endl;
int iTreeSize = m_pTreeStart[i]; SizeType iTreeSize = m_pTreeStart[i];
DivideTree<T>(p_index, pindices, 0, (int)pindices.size() - 1, m_pTreeStart[i], iTreeSize); DivideTree<T>(p_index, pindices, 0, (SizeType)pindices.size() - 1, m_pTreeStart[i], iTreeSize);
std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl; std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl;
} }
} }
bool SaveTrees(void **pKDTMemFile, int64_t &len) const inline std::uint64_t BufferSize() const
{ {
int treeNodeSize = (int)m_pTreeRoots.size(); return sizeof(int) + sizeof(SizeType) * m_iTreeNumber +
sizeof(SizeType) + sizeof(KDTNode) * m_pTreeRoots.size();
size_t size = sizeof(int) + }
sizeof(int) * m_iTreeNumber +
sizeof(int) +
sizeof(KDTNode) * treeNodeSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iTreeNumber;
ptr += sizeof(int);
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
ptr += sizeof(int) * m_iTreeNumber;
*(int*)ptr = treeNodeSize;
ptr += sizeof(int);
memcpy(ptr, m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
*pKDTMemFile = mem;
len = size;
bool SaveTrees(std::ostream& p_outstream) const
{
p_outstream.write((char*)&m_iTreeNumber, sizeof(int));
p_outstream.write((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber);
SizeType treeNodeSize = (SizeType)m_pTreeRoots.size();
p_outstream.write((char*)&treeNodeSize, sizeof(SizeType));
p_outstream.write((char*)m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true; return true;
} }
bool SaveTrees(std::string sTreeFileName) const bool SaveTrees(std::string sTreeFileName) const
{ {
std::cout << "Save KDT to " << sTreeFileName << std::endl; std::cout << "Save KDT to " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "wb"); std::ofstream output(sTreeFileName, std::ios::binary);
if (fp == NULL) return false; if (!output.is_open()) return false;
SaveTrees(output);
fwrite(&m_iTreeNumber, sizeof(int), 1, fp); output.close();
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize = (int)m_pTreeRoots.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
fwrite(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true; return true;
} }
...@@ -123,31 +106,32 @@ namespace SPTAG ...@@ -123,31 +106,32 @@ namespace SPTAG
m_iTreeNumber = *((int*)pKDTMemFile); m_iTreeNumber = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int); pKDTMemFile += sizeof(int);
m_pTreeStart.resize(m_iTreeNumber); m_pTreeStart.resize(m_iTreeNumber);
memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(int) * m_iTreeNumber); memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(SizeType) * m_iTreeNumber);
pKDTMemFile += sizeof(int)*m_iTreeNumber; pKDTMemFile += sizeof(SizeType)*m_iTreeNumber;
int treeNodeSize = *((int*)pKDTMemFile); SizeType treeNodeSize = *((SizeType*)pKDTMemFile);
pKDTMemFile += sizeof(int); pKDTMemFile += sizeof(SizeType);
m_pTreeRoots.resize(treeNodeSize); m_pTreeRoots.resize(treeNodeSize);
memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize); memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize);
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true; return true;
} }
bool LoadTrees(std::string sTreeFileName) bool LoadTrees(std::string sTreeFileName)
{ {
std::cout << "Load KDT From " << sTreeFileName << std::endl; std::cout << "Load KDT From " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "rb"); std::ifstream input(sTreeFileName, std::ios::binary);
if (fp == NULL) return false; if (!input.is_open()) return false;
fread(&m_iTreeNumber, sizeof(int), 1, fp); input.read((char*)&m_iTreeNumber, sizeof(int));
m_pTreeStart.resize(m_iTreeNumber); m_pTreeStart.resize(m_iTreeNumber);
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp); input.read((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber);
int treeNodeSize; SizeType treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp); input.read((char*)&treeNodeSize, sizeof(SizeType));
m_pTreeRoots.resize(treeNodeSize); m_pTreeRoots.resize(treeNodeSize);
fread(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp); input.read((char*)m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
fclose(fp); input.close();
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true; return true;
} }
...@@ -155,7 +139,7 @@ namespace SPTAG ...@@ -155,7 +139,7 @@ namespace SPTAG
template <typename T> template <typename T>
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{ {
for (char i = 0; i < m_iTreeNumber; i++) { for (int i = 0; i < m_iTreeNumber; i++) {
KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0); KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0);
} }
...@@ -181,10 +165,10 @@ namespace SPTAG ...@@ -181,10 +165,10 @@ namespace SPTAG
template <typename T> template <typename T>
void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
COMMON::WorkSpace& p_space, const int node, const bool isInit, const float distBound) const { COMMON::WorkSpace& p_space, const SizeType node, const bool isInit, const float distBound) const {
if (node < 0) if (node < 0)
{ {
int index = -node - 1; SizeType index = -node - 1;
if (index >= p_index->GetNumSamples()) return; if (index >= p_index->GetNumSamples()) return;
#ifdef PREFETCH #ifdef PREFETCH
const char* data = (const char *)(p_index->GetSample(index)); const char* data = (const char *)(p_index->GetSample(index));
...@@ -203,7 +187,7 @@ namespace SPTAG ...@@ -203,7 +187,7 @@ namespace SPTAG
float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value; float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
float distanceBound = distBound + diff * diff; float distanceBound = distBound + diff * diff;
int otherChild, bestChild; SizeType otherChild, bestChild;
if (diff < 0) if (diff < 0)
{ {
bestChild = tnode.left; bestChild = tnode.left;
...@@ -224,10 +208,10 @@ namespace SPTAG ...@@ -224,10 +208,10 @@ namespace SPTAG
template <typename T> template <typename T>
void DivideTree(VectorIndex* p_index, std::vector<int>& indices, int first, int last, void DivideTree(VectorIndex* p_index, std::vector<SizeType>& indices, SizeType first, SizeType last,
int index, int &iTreeSize) { SizeType index, SizeType &iTreeSize) {
ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last); ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last);
int i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last); SizeType i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last);
if (i - 1 <= first) if (i - 1 <= first)
{ {
m_pTreeRoots[index].left = -indices[first] - 1; m_pTreeRoots[index].left = -indices[first] - 1;
...@@ -251,30 +235,30 @@ namespace SPTAG ...@@ -251,30 +235,30 @@ namespace SPTAG
} }
template <typename T> template <typename T>
void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<int>& indices, const int first, const int last) void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<SizeType>& indices, const SizeType first, const SizeType last)
{ {
std::vector<float> meanValues(p_index->GetFeatureDim(), 0); std::vector<float> meanValues(p_index->GetFeatureDim(), 0);
std::vector<float> varianceValues(p_index->GetFeatureDim(), 0); std::vector<float> varianceValues(p_index->GetFeatureDim(), 0);
int end = min(first + m_iSamples, last); SizeType end = min(first + m_iSamples, last);
int count = end - first + 1; SizeType count = end - first + 1;
// calculate the mean of each dimension // calculate the mean of each dimension
for (int j = first; j <= end; j++) for (SizeType j = first; j <= end; j++)
{ {
const T* v = (const T*)p_index->GetSample(indices[j]); const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++) for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++)
{ {
meanValues[k] += v[k]; meanValues[k] += v[k];
} }
} }
for (int k = 0; k < p_index->GetFeatureDim(); k++) for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++)
{ {
meanValues[k] /= count; meanValues[k] /= count;
} }
// calculate the variance of each dimension // calculate the variance of each dimension
for (int j = first; j <= end; j++) for (SizeType j = first; j <= end; j++)
{ {
const T* v = (const T*)p_index->GetSample(indices[j]); const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++) for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++)
{ {
float dist = v[k] - meanValues[k]; float dist = v[k] - meanValues[k];
varianceValues[k] += dist*dist; varianceValues[k] += dist*dist;
...@@ -286,13 +270,13 @@ namespace SPTAG ...@@ -286,13 +270,13 @@ namespace SPTAG
node.split_value = meanValues[node.split_dim]; node.split_value = meanValues[node.split_dim];
} }
int SelectDivisionDimension(const std::vector<float>& varianceValues) const DimensionType SelectDivisionDimension(const std::vector<float>& varianceValues) const
{ {
// Record the top maximum variances // Record the top maximum variances
std::vector<int> topind(m_numTopDimensionKDTSplit); std::vector<DimensionType> topind(m_numTopDimensionKDTSplit);
int num = 0; int num = 0;
// order the variances // order the variances
for (int i = 0; i < varianceValues.size(); i++) for (DimensionType i = 0; i < (DimensionType)varianceValues.size(); i++)
{ {
if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]]) if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]])
{ {
...@@ -314,18 +298,18 @@ namespace SPTAG ...@@ -314,18 +298,18 @@ namespace SPTAG
} }
} }
// randomly choose a dimension from TOP_DIM // randomly choose a dimension from TOP_DIM
return topind[COMMON::Utils::rand_int(num)]; return topind[COMMON::Utils::rand(num)];
} }
template <typename T> template <typename T>
int Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<int>& indices, const int first, const int last) const SizeType Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<SizeType>& indices, const SizeType first, const SizeType last) const
{ {
int i = first; SizeType i = first;
int j = last; SizeType j = last;
// decide which child one point belongs // decide which child one point belongs
while (i <= j) while (i <= j)
{ {
int ind = indices[i]; SizeType ind = indices[i];
const T* v = (const T*)p_index->GetSample(ind); const T* v = (const T*)p_index->GetSample(ind);
float val = v[node.split_dim]; float val = v[node.split_dim];
if (val < node.split_value) if (val < node.split_value)
...@@ -347,7 +331,7 @@ namespace SPTAG ...@@ -347,7 +331,7 @@ namespace SPTAG
} }
private: private:
std::vector<int> m_pTreeStart; std::vector<SizeType> m_pTreeStart;
std::vector<KDTNode> m_pTreeRoots; std::vector<KDTNode> m_pTreeRoots;
public: public:
......
...@@ -27,18 +27,21 @@ namespace SPTAG ...@@ -27,18 +27,21 @@ namespace SPTAG
m_iCEFScale(2), m_iCEFScale(2),
m_iRefineIter(0), m_iRefineIter(0),
m_iCEF(1000), m_iCEF(1000),
m_iMaxCheckForRefineGraph(10000) {} m_iMaxCheckForRefineGraph(10000)
{
m_pNeighborhoodGraph.SetName("Graph");
}
~NeighborhoodGraph() {} ~NeighborhoodGraph() {}
virtual void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist) = 0; virtual void InsertNeighbors(VectorIndex* index, const SizeType node, SizeType insertNode, float insertDist) = 0;
virtual void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) = 0; virtual void RebuildNeighbors(VectorIndex* index, const SizeType node, SizeType* nodes, const BasicResult* queryResults, const int numResults) = 0;
virtual float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr) = 0; virtual float GraphAccuracyEstimation(VectorIndex* index, const SizeType samples, const std::unordered_map<SizeType, SizeType>* idmap = nullptr) = 0;
template <typename T> template <typename T>
void BuildGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr) void BuildGraph(VectorIndex* index, const std::unordered_map<SizeType, SizeType>* idmap = nullptr)
{ {
std::cout << "build RNG graph!" << std::endl; std::cout << "build RNG graph!" << std::endl;
...@@ -55,11 +58,11 @@ namespace SPTAG ...@@ -55,11 +58,11 @@ namespace SPTAG
{ {
COMMON::Dataset<float> NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize); COMMON::Dataset<float> NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
std::vector<std::vector<int>> TptreeDataIndices(m_iTPTNumber, std::vector<int>(m_iGraphSize)); std::vector<std::vector<SizeType>> TptreeDataIndices(m_iTPTNumber, std::vector<SizeType>(m_iGraphSize));
std::vector<std::vector<std::pair<int, int>>> TptreeLeafNodes(m_iTPTNumber, std::vector<std::pair<int, int>>()); std::vector<std::vector<std::pair<SizeType, SizeType>>> TptreeLeafNodes(m_iTPTNumber, std::vector<std::pair<SizeType, SizeType>>());
for (int i = 0; i < m_iGraphSize; i++) for (SizeType i = 0; i < m_iGraphSize; i++)
for (int j = 0; j < m_iNeighborhoodSize; j++) for (DimensionType j = 0; j < m_iNeighborhoodSize; j++)
(NeighborhoodDists)[i][j] = MaxDist; (NeighborhoodDists)[i][j] = MaxDist;
std::cout << "Parallel TpTree Partition begin " << std::endl; std::cout << "Parallel TpTree Partition begin " << std::endl;
...@@ -67,7 +70,7 @@ namespace SPTAG ...@@ -67,7 +70,7 @@ namespace SPTAG
for (int i = 0; i < m_iTPTNumber; i++) for (int i = 0; i < m_iTPTNumber; i++)
{ {
Sleep(i * 100); std::srand(clock()); Sleep(i * 100); std::srand(clock());
for (int j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j; for (SizeType j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j;
std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end()); std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end());
PartitionByTptree<T>(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]); PartitionByTptree<T>(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]);
std::cout << "Finish Getting Leaves for Tree " << i << std::endl; std::cout << "Finish Getting Leaves for Tree " << i << std::endl;
...@@ -77,17 +80,17 @@ namespace SPTAG ...@@ -77,17 +80,17 @@ namespace SPTAG
for (int i = 0; i < m_iTPTNumber; i++) for (int i = 0; i < m_iTPTNumber; i++)
{ {
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int j = 0; j < TptreeLeafNodes[i].size(); j++) for (SizeType j = 0; j < (SizeType)TptreeLeafNodes[i].size(); j++)
{ {
int start_index = TptreeLeafNodes[i][j].first; SizeType start_index = TptreeLeafNodes[i][j].first;
int end_index = TptreeLeafNodes[i][j].second; SizeType end_index = TptreeLeafNodes[i][j].second;
if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%'; if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%';
for (int x = start_index; x < end_index; x++) for (SizeType x = start_index; x < end_index; x++)
{ {
for (int y = x + 1; y <= end_index; y++) for (SizeType y = x + 1; y <= end_index; y++)
{ {
int p1 = TptreeDataIndices[i][x]; SizeType p1 = TptreeDataIndices[i][x];
int p2 = TptreeDataIndices[i][y]; SizeType p2 = TptreeDataIndices[i][y];
float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2)); float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2));
if (idmap != nullptr) { if (idmap != nullptr) {
p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1); p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1);
...@@ -112,13 +115,13 @@ namespace SPTAG ...@@ -112,13 +115,13 @@ namespace SPTAG
} }
template <typename T> template <typename T>
void RefineGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr) void RefineGraph(VectorIndex* index, const std::unordered_map<SizeType, SizeType>* idmap = nullptr)
{ {
m_iCEF *= m_iCEFScale; m_iCEF *= m_iCEFScale;
m_iMaxCheckForRefineGraph *= m_iCEFScale; m_iMaxCheckForRefineGraph *= m_iCEFScale;
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++) for (SizeType i = 0; i < m_iGraphSize; i++)
{ {
RefineNode<T>(index, i, false); RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%"; if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%";
...@@ -130,7 +133,7 @@ namespace SPTAG ...@@ -130,7 +133,7 @@ namespace SPTAG
m_iNeighborhoodSize /= m_iNeighborhoodScale; m_iNeighborhoodSize /= m_iNeighborhoodScale;
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++) for (SizeType i = 0; i < m_iGraphSize; i++)
{ {
RefineNode<T>(index, i, false); RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%"; if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%";
...@@ -147,17 +150,17 @@ namespace SPTAG ...@@ -147,17 +150,17 @@ namespace SPTAG
} }
template <typename T> template <typename T>
ErrorCode RefineGraph(VectorIndex* index, std::vector<int>& indices, std::vector<int>& reverseIndices, ErrorCode RefineGraph(VectorIndex* index, std::vector<SizeType>& indices, std::vector<SizeType>& reverseIndices,
std::string graphFileName, const std::unordered_map<int, int>* idmap = nullptr) std::ostream& output, const std::unordered_map<SizeType, SizeType>* idmap = nullptr)
{ {
int R = (int)indices.size(); SizeType R = (SizeType)indices.size();
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < R; i++) for (SizeType i = 0; i < R; i++)
{ {
RefineNode<T>(index, indices[i], false); RefineNode<T>(index, indices[i], false);
int* nodes = m_pNeighborhoodGraph[indices[i]]; SizeType* nodes = m_pNeighborhoodGraph[indices[i]];
for (int j = 0; j < m_iNeighborhoodSize; j++) for (DimensionType j = 0; j < m_iNeighborhoodSize; j++)
{ {
if (nodes[j] < 0) nodes[j] = -1; if (nodes[j] < 0) nodes[j] = -1;
else nodes[j] = reverseIndices[nodes[j]]; else nodes[j] = reverseIndices[nodes[j]];
...@@ -166,20 +169,13 @@ namespace SPTAG ...@@ -166,20 +169,13 @@ namespace SPTAG
nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]); nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]);
} }
std::ofstream graphOut(graphFileName, std::ios::binary); m_pNeighborhoodGraph.Refine(indices, output);
if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
graphOut.write((char*)&R, sizeof(int));
graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
for (int i = 0; i < R; i++) {
graphOut.write((char*)m_pNeighborhoodGraph[indices[i]], sizeof(int) * m_iNeighborhoodSize);
}
graphOut.close();
return ErrorCode::Success; return ErrorCode::Success;
} }
template <typename T> template <typename T>
void RefineNode(VectorIndex* index, const int node, bool updateNeighbors) void RefineNode(VectorIndex* index, const SizeType node, bool updateNeighbors)
{ {
COMMON::QueryResultSet<T> query((const T*)index->GetSample(node), m_iCEF + 1); COMMON::QueryResultSet<T> query((const T*)index->GetSample(node), m_iCEF + 1);
index->SearchIndex(query); index->SearchIndex(query);
...@@ -200,8 +196,8 @@ namespace SPTAG ...@@ -200,8 +196,8 @@ namespace SPTAG
} }
template <typename T> template <typename T>
void PartitionByTptree(VectorIndex* index, std::vector<int>& indices, const int first, const int last, void PartitionByTptree(VectorIndex* index, std::vector<SizeType>& indices, const SizeType first, const SizeType last,
std::vector<std::pair<int, int>> & leaves) std::vector<std::pair<SizeType, SizeType>> & leaves)
{ {
if (last - first <= m_iTPTLeafSize) if (last - first <= m_iTPTLeafSize)
{ {
...@@ -212,39 +208,39 @@ namespace SPTAG ...@@ -212,39 +208,39 @@ namespace SPTAG
std::vector<float> Mean(index->GetFeatureDim(), 0); std::vector<float> Mean(index->GetFeatureDim(), 0);
int iIteration = 100; int iIteration = 100;
int end = min(first + m_iSamples, last); SizeType end = min(first + m_iSamples, last);
int count = end - first + 1; SizeType count = end - first + 1;
// calculate the mean of each dimension // calculate the mean of each dimension
for (int j = first; j <= end; j++) for (SizeType j = first; j <= end; j++)
{ {
const T* v = (const T*)index->GetSample(indices[j]); const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++) for (DimensionType k = 0; k < index->GetFeatureDim(); k++)
{ {
Mean[k] += v[k]; Mean[k] += v[k];
} }
} }
for (int k = 0; k < index->GetFeatureDim(); k++) for (DimensionType k = 0; k < index->GetFeatureDim(); k++)
{ {
Mean[k] /= count; Mean[k] /= count;
} }
std::vector<BasicResult> Variance; std::vector<BasicResult> Variance;
Variance.reserve(index->GetFeatureDim()); Variance.reserve(index->GetFeatureDim());
for (int j = 0; j < index->GetFeatureDim(); j++) for (DimensionType j = 0; j < index->GetFeatureDim(); j++)
{ {
Variance.push_back(BasicResult(j, 0)); Variance.push_back(BasicResult(j, 0));
} }
// calculate the variance of each dimension // calculate the variance of each dimension
for (int j = first; j <= end; j++) for (SizeType j = first; j <= end; j++)
{ {
const T* v = (const T*)index->GetSample(indices[j]); const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++) for (DimensionType k = 0; k < index->GetFeatureDim(); k++)
{ {
float dist = v[k] - Mean[k]; float dist = v[k] - Mean[k];
Variance[k].Dist += dist*dist; Variance[k].Dist += dist*dist;
} }
} }
std::sort(Variance.begin(), Variance.end(), COMMON::Compare); std::sort(Variance.begin(), Variance.end(), COMMON::Compare);
std::vector<int> indexs(m_numTopDimensionTPTSplit); std::vector<SizeType> indexs(m_numTopDimensionTPTSplit);
std::vector<float> weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit); std::vector<float> weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit);
float bestvariance = Variance[index->GetFeatureDim() - 1].Dist; float bestvariance = Variance[index->GetFeatureDim() - 1].Dist;
for (int i = 0; i < m_numTopDimensionTPTSplit; i++) for (int i = 0; i < m_numTopDimensionTPTSplit; i++)
...@@ -270,7 +266,7 @@ namespace SPTAG ...@@ -270,7 +266,7 @@ namespace SPTAG
weight[j] /= sumweight; weight[j] /= sumweight;
} }
float mean = 0; float mean = 0;
for (int j = 0; j < count; j++) for (SizeType j = 0; j < count; j++)
{ {
Val[j] = 0; Val[j] = 0;
const T* v = (const T*)index->GetSample(indices[first + j]); const T* v = (const T*)index->GetSample(indices[first + j]);
...@@ -282,7 +278,7 @@ namespace SPTAG ...@@ -282,7 +278,7 @@ namespace SPTAG
} }
mean /= count; mean /= count;
float var = 0; float var = 0;
for (int j = 0; j < count; j++) for (SizeType j = 0; j < count; j++)
{ {
float dist = Val[j] - mean; float dist = Val[j] - mean;
var += dist * dist; var += dist * dist;
...@@ -297,8 +293,8 @@ namespace SPTAG ...@@ -297,8 +293,8 @@ namespace SPTAG
} }
} }
} }
int i = first; SizeType i = first;
int j = last; SizeType j = last;
// decide which child one point belongs // decide which child one point belongs
while (i <= j) while (i <= j)
{ {
...@@ -336,100 +332,71 @@ namespace SPTAG ...@@ -336,100 +332,71 @@ namespace SPTAG
} }
} }
// Reports the buffer size of the neighborhood graph by forwarding to the
// underlying dataset's BufferSize().
inline std::uint64_t BufferSize() const
{
    const std::uint64_t graphBufferSize = m_pNeighborhoodGraph.BufferSize();
    return graphBufferSize;
}
bool LoadGraph(std::string sGraphFilename) bool LoadGraph(std::string sGraphFilename)
{ {
std::cout << "Load Graph From " << sGraphFilename << std::endl; if (!m_pNeighborhoodGraph.Load(sGraphFilename)) return false;
FILE * fp = fopen(sGraphFilename.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iGraphSize, sizeof(int), 1, fp); m_iGraphSize = m_pNeighborhoodGraph.R();
fread(&m_iNeighborhoodSize, sizeof(int), 1, fp); m_iNeighborhoodSize = m_pNeighborhoodGraph.C();
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
m_dataUpdateLock.resize(m_iGraphSize); m_dataUpdateLock.resize(m_iGraphSize);
for (int i = 0; i < m_iGraphSize; i++)
{
fread((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Load Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true; return true;
} }
bool LoadGraphFromMemory(char* pGraphMemFile) bool LoadGraph(char* pGraphMemFile)
{ {
m_iGraphSize = *((int*)pGraphMemFile); m_pNeighborhoodGraph.Load(pGraphMemFile);
pGraphMemFile += sizeof(int);
m_iNeighborhoodSize = *((int*)pGraphMemFile); m_iGraphSize = m_pNeighborhoodGraph.R();
pGraphMemFile += sizeof(int); m_iNeighborhoodSize = m_pNeighborhoodGraph.C();
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize, (int*)pGraphMemFile);
m_dataUpdateLock.resize(m_iGraphSize); m_dataUpdateLock.resize(m_iGraphSize);
return true; return true;
} }
bool SaveGraph(std::string sGraphFilename) const bool SaveGraph(std::string sGraphFilename) const
{ {
std::cout << "Save Graph To " << sGraphFilename << std::endl; return m_pNeighborhoodGraph.Save(sGraphFilename);
FILE *fp = fopen(sGraphFilename.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iGraphSize, sizeof(int), 1, fp);
fwrite(&m_iNeighborhoodSize, sizeof(int), 1, fp);
for (int i = 0; i < m_iGraphSize; i++)
{
fwrite((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Save Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true;
} }
bool SaveGraphToMemory(void **pGraphMemFile, int64_t &len) { bool SaveGraph(std::ostream& output) const
size_t size = sizeof(int) + sizeof(int) + sizeof(int) * m_iNeighborhoodSize * m_iGraphSize; {
char *mem = (char*)malloc(size); return m_pNeighborhoodGraph.Save(output);
if (mem == NULL) return false; }
auto ptr = mem;
*(int*)ptr = m_iGraphSize;
ptr += sizeof(int);
*(int*)ptr = m_iNeighborhoodSize;
ptr += sizeof(int);
for (int i = 0; i < m_iGraphSize; i++) inline ErrorCode AddBatch(SizeType num)
{ {
memcpy(ptr, (m_pNeighborhoodGraph)[i], sizeof(int) * m_iNeighborhoodSize); ErrorCode ret = m_pNeighborhoodGraph.AddBatch(num);
ptr += sizeof(int) * m_iNeighborhoodSize; if (ret != ErrorCode::Success) return ret;
}
*pGraphMemFile = mem;
len = size;
return true; m_iGraphSize += num;
m_dataUpdateLock.resize(m_iGraphSize);
return ErrorCode::Success;
} }
inline void AddBatch(int num) { m_pNeighborhoodGraph.AddBatch(num); m_iGraphSize += num; m_dataUpdateLock.resize(m_iGraphSize); } inline SizeType* operator[](SizeType index) { return m_pNeighborhoodGraph[index]; }
inline int* operator[](int index) { return m_pNeighborhoodGraph[index]; }
inline const int* operator[](int index) const { return m_pNeighborhoodGraph[index]; } inline const SizeType* operator[](SizeType index) const { return m_pNeighborhoodGraph[index]; }
inline void SetR(int rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); } inline void SetR(SizeType rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); }
inline int R() const { return m_iGraphSize; } inline SizeType R() const { return m_iGraphSize; }
static std::shared_ptr<NeighborhoodGraph> CreateInstance(std::string type); static std::shared_ptr<NeighborhoodGraph> CreateInstance(std::string type);
protected: protected:
// Graph structure // Graph structure
int m_iGraphSize; SizeType m_iGraphSize;
COMMON::Dataset<int> m_pNeighborhoodGraph; COMMON::Dataset<SizeType> m_pNeighborhoodGraph;
COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph
public: public:
int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit; int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit;
int m_iNeighborhoodSize, m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph; DimensionType m_iNeighborhoodSize;
int m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph;
}; };
} }
} }
......
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
return m_results[0].Dist; return m_results[0].Dist;
} }
bool AddPoint(const int index, float dist) bool AddPoint(const SizeType index, float dist)
{ {
if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID)) if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID))
{ {
......
...@@ -13,15 +13,15 @@ namespace SPTAG ...@@ -13,15 +13,15 @@ namespace SPTAG
class RelativeNeighborhoodGraph: public NeighborhoodGraph class RelativeNeighborhoodGraph: public NeighborhoodGraph
{ {
public: public:
void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) { void RebuildNeighbors(VectorIndex* index, const SizeType node, SizeType* nodes, const BasicResult* queryResults, const int numResults) {
int count = 0; DimensionType count = 0;
for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) { for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) {
const BasicResult& item = queryResults[j]; const BasicResult& item = queryResults[j];
if (item.VID < 0) break; if (item.VID < 0) break;
if (item.VID == node) continue; if (item.VID == node) continue;
bool good = true; bool good = true;
for (int k = 0; k < count; k++) { for (DimensionType k = 0; k < count; k++) {
if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) { if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) {
good = false; good = false;
break; break;
...@@ -29,21 +29,21 @@ namespace SPTAG ...@@ -29,21 +29,21 @@ namespace SPTAG
} }
if (good) nodes[count++] = item.VID; if (good) nodes[count++] = item.VID;
} }
for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1; for (DimensionType j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1;
} }
void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist) void InsertNeighbors(VectorIndex* index, const SizeType node, SizeType insertNode, float insertDist)
{ {
int* nodes = m_pNeighborhoodGraph[node]; SizeType* nodes = m_pNeighborhoodGraph[node];
for (int k = 0; k < m_iNeighborhoodSize; k++) for (DimensionType k = 0; k < m_iNeighborhoodSize; k++)
{ {
int tmpNode = nodes[k]; SizeType tmpNode = nodes[k];
if (tmpNode < -1) continue; if (tmpNode < -1) continue;
if (tmpNode < 0) if (tmpNode < 0)
{ {
bool good = true; bool good = true;
for (int t = 0; t < k; t++) { for (DimensionType t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) { if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false; good = false;
break; break;
...@@ -58,7 +58,7 @@ namespace SPTAG ...@@ -58,7 +58,7 @@ namespace SPTAG
if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode)) if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode))
{ {
bool good = true; bool good = true;
for (int t = 0; t < k; t++) { for (DimensionType t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) { if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false; good = false;
break; break;
...@@ -76,33 +76,33 @@ namespace SPTAG ...@@ -76,33 +76,33 @@ namespace SPTAG
} }
} }
float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr) float GraphAccuracyEstimation(VectorIndex* index, const SizeType samples, const std::unordered_map<SizeType, SizeType>* idmap = nullptr)
{ {
int* correct = new int[samples]; DimensionType* correct = new DimensionType[samples];
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < samples; i++) for (SizeType i = 0; i < samples; i++)
{ {
int x = COMMON::Utils::rand_int(m_iGraphSize); SizeType x = COMMON::Utils::rand(m_iGraphSize);
//int x = i; //int x = i;
COMMON::QueryResultSet<void> query(nullptr, m_iCEF); COMMON::QueryResultSet<void> query(nullptr, m_iCEF);
for (int y = 0; y < m_iGraphSize; y++) for (SizeType y = 0; y < m_iGraphSize; y++)
{ {
if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue; if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue;
float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y)); float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y));
query.AddPoint(y, dist); query.AddPoint(y, dist);
} }
query.SortResult(); query.SortResult();
int * exact_rng = new int[m_iNeighborhoodSize]; SizeType * exact_rng = new SizeType[m_iNeighborhoodSize];
RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF); RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF);
correct[i] = 0; correct[i] = 0;
for (int j = 0; j < m_iNeighborhoodSize; j++) { for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) {
if (exact_rng[j] == -1) { if (exact_rng[j] == -1) {
correct[i] += m_iNeighborhoodSize - j; correct[i] += m_iNeighborhoodSize - j;
break; break;
} }
for (int k = 0; k < m_iNeighborhoodSize; k++) for (DimensionType k = 0; k < m_iNeighborhoodSize; k++)
if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) { if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) {
correct[i]++; correct[i]++;
break; break;
...@@ -111,7 +111,7 @@ namespace SPTAG ...@@ -111,7 +111,7 @@ namespace SPTAG
delete[] exact_rng; delete[] exact_rng;
} }
float acc = 0; float acc = 0;
for (int i = 0; i < samples; i++) acc += float(correct[i]); for (SizeType i = 0; i < samples; i++) acc += float(correct[i]);
acc = acc / samples / m_iNeighborhoodSize; acc = acc / samples / m_iNeighborhoodSize;
delete[] correct; delete[] correct;
return acc; return acc;
......
...@@ -14,10 +14,10 @@ namespace SPTAG ...@@ -14,10 +14,10 @@ namespace SPTAG
// node type in the priority queue // node type in the priority queue
struct HeapCell struct HeapCell
{ {
int node; SizeType node;
float distance; float distance;
HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {} HeapCell(SizeType _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {}
inline bool operator < (const HeapCell& rhs) inline bool operator < (const HeapCell& rhs)
{ {
...@@ -45,12 +45,12 @@ namespace SPTAG ...@@ -45,12 +45,12 @@ namespace SPTAG
// Record 2 hash tables. // Record 2 hash tables.
// [0~m_poolSize + 1) is the first block. // [0~m_poolSize + 1) is the first block.
// [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block; // [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block;
int m_hashTable[(m_poolSize + 1) * 2]; SizeType m_hashTable[(m_poolSize + 1) * 2];
inline unsigned hash_func2(int idx, int loop) inline unsigned hash_func2(unsigned idx, int loop)
{ {
return ((unsigned)idx + loop) & m_poolSize; return (idx + loop) & m_poolSize;
} }
...@@ -65,7 +65,7 @@ namespace SPTAG ...@@ -65,7 +65,7 @@ namespace SPTAG
~OptHashPosVector() {} ~OptHashPosVector() {}
void Init(int size) void Init(SizeType size)
{ {
m_secondHash = true; m_secondHash = true;
clear(); clear();
...@@ -76,31 +76,31 @@ namespace SPTAG ...@@ -76,31 +76,31 @@ namespace SPTAG
if (!m_secondHash) if (!m_secondHash)
{ {
// Clear first block. // Clear first block.
memset(&m_hashTable[0], 0, sizeof(int)*(m_poolSize + 1)); memset(&m_hashTable[0], 0, sizeof(SizeType)*(m_poolSize + 1));
} }
else else
{ {
// Clear all blocks. // Clear all blocks.
memset(&m_hashTable[0], 0, 2 * sizeof(int) * (m_poolSize + 1)); memset(&m_hashTable[0], 0, 2 * sizeof(SizeType) * (m_poolSize + 1));
m_secondHash = false; m_secondHash = false;
} }
} }
inline bool CheckAndSet(int idx) inline bool CheckAndSet(SizeType idx)
{ {
// Inner Index is begin from 1 // Inner Index is begin from 1
return _CheckAndSet(&m_hashTable[0], idx + 1) == 0; return _CheckAndSet(&m_hashTable[0], idx + 1) == 0;
} }
inline int _CheckAndSet(int* hashTable, int idx) inline int _CheckAndSet(SizeType* hashTable, SizeType idx)
{ {
unsigned index, loop; unsigned index;
// Get first hash position. // Get first hash position.
index = hash_func(idx); index = hash_func((unsigned)idx);
for (loop = 0; loop < m_maxLoop; ++loop) for (int loop = 0; loop < m_maxLoop; ++loop)
{ {
if (!hashTable[index]) if (!hashTable[index])
{ {
...@@ -132,7 +132,7 @@ namespace SPTAG ...@@ -132,7 +132,7 @@ namespace SPTAG
// Variables for each single NN search // Variables for each single NN search
struct WorkSpace struct WorkSpace
{ {
void Initialize(int maxCheck, int dataSize) void Initialize(int maxCheck, SizeType dataSize)
{ {
nodeCheckStatus.Init(dataSize); nodeCheckStatus.Init(dataSize);
m_SPTQueue.Resize(maxCheck * 10); m_SPTQueue.Resize(maxCheck * 10);
...@@ -158,7 +158,7 @@ namespace SPTAG ...@@ -158,7 +158,7 @@ namespace SPTAG
m_iNumOfContinuousNoBetterPropagation = 0; m_iNumOfContinuousNoBetterPropagation = 0;
} }
inline bool CheckAndSet(int idx) inline bool CheckAndSet(SizeType idx)
{ {
return nodeCheckStatus.CheckAndSet(idx); return nodeCheckStatus.CheckAndSet(idx);
} }
......
...@@ -17,7 +17,7 @@ namespace COMMON ...@@ -17,7 +17,7 @@ namespace COMMON
class WorkSpacePool class WorkSpacePool
{ {
public: public:
WorkSpacePool(int p_maxCheck, int p_vectorCount); WorkSpacePool(int p_maxCheck, SizeType p_vectorCount);
virtual ~WorkSpacePool(); virtual ~WorkSpacePool();
...@@ -34,7 +34,7 @@ private: ...@@ -34,7 +34,7 @@ private:
int m_maxCheck; int m_maxCheck;
int m_vectorCount; SizeType m_vectorCount;
}; };
} }
......
...@@ -4,53 +4,223 @@ ...@@ -4,53 +4,223 @@
#ifndef _SPTAG_COMMONDATASTRUCTURE_H_ #ifndef _SPTAG_COMMONDATASTRUCTURE_H_
#define _SPTAG_COMMONDATASTRUCTURE_H_ #define _SPTAG_COMMONDATASTRUCTURE_H_
#include "Common.h" #include "inc/Core/Common.h"
namespace SPTAG namespace SPTAG
{ {
class ByteArray template<typename T>
class Array
{ {
public: public:
ByteArray(); Array();
ByteArray(ByteArray&& p_right); Array(T* p_array, std::size_t p_length, bool p_transferOwnership);
Array(T* p_array, std::size_t p_length, std::shared_ptr<T> p_dataHolder);
ByteArray(std::uint8_t* p_array, std::size_t p_length, bool p_transferOnwership); Array(Array<T>&& p_right);
ByteArray(std::uint8_t* p_array, std::size_t p_length, std::shared_ptr<std::uint8_t> p_dataHolder); Array(const Array<T>& p_right);
ByteArray(const ByteArray& p_right); Array<T>& operator= (Array<T>&& p_right);
ByteArray& operator= (const ByteArray& p_right); Array<T>& operator= (const Array<T>& p_right);
ByteArray& operator= (ByteArray&& p_right); T& operator[] (std::size_t p_index);
~ByteArray(); const T& operator[] (std::size_t p_index) const;
static ByteArray Alloc(std::size_t p_length); ~Array();
std::uint8_t* Data() const; T* Data() const;
std::size_t Length() const; std::size_t Length() const;
void SetData(std::uint8_t* p_array, std::size_t p_length);
std::shared_ptr<std::uint8_t> DataHolder() const; std::shared_ptr<T> DataHolder() const;
void Set(T* p_array, std::size_t p_length, bool p_transferOwnership);
void Clear(); void Clear();
const static ByteArray c_empty; static Array<T> Alloc(std::size_t p_length);
const static Array<T> c_empty;
private: private:
std::uint8_t* m_data; T* m_data;
std::size_t m_length; std::size_t m_length;
// Notice this is holding an array. Set correct deleter for this. // Notice this is holding an array. Set correct deleter for this.
std::shared_ptr<std::uint8_t> m_dataHolder; std::shared_ptr<T> m_dataHolder;
}; };
// Definition of the shared empty array constant: a default-constructed
// Array<T> (null data pointer, zero length, no ownership handle).
template<typename T>
const Array<T> Array<T>::c_empty{};
// Default constructor: creates an empty array that points at nothing, has
// length zero and holds no ownership handle.
template<typename T>
Array<T>::Array()
    : m_data(nullptr), m_length(0), m_dataHolder()
{
}
// Wraps an existing buffer of p_length elements starting at p_array.
//
// p_transferOwnership == true : the array takes ownership of p_array and will
//                               release it with delete[] once the last
//                               ownership handle goes away.
// p_transferOwnership == false: the array is a plain view; no ownership handle
//                               is created and the buffer is never freed here.
template<typename T>
Array<T>::Array(T* p_array, std::size_t p_length, bool p_transferOwnership)
    : m_data(p_array),
      m_length(p_length),
      m_dataHolder(p_transferOwnership
                       ? std::shared_ptr<T>(p_array, std::default_delete<T[]>())
                       : std::shared_ptr<T>())
{
}
// Wraps a buffer of p_length elements starting at p_array and adopts the
// caller-supplied ownership handle p_dataHolder (moved into the array).
// The handle is stored as given; it is not checked against p_array.
template<typename T>
Array<T>::Array(T* p_array, std::size_t p_length, std::shared_ptr<T> p_dataHolder)
    : m_data(p_array), m_length(p_length), m_dataHolder(std::move(p_dataHolder))
{
}
// Move constructor: takes over p_right's data pointer, length and ownership
// handle, and leaves p_right empty.
//
// The source is reset explicitly: moving only the shared_ptr would leave
// p_right.m_data / p_right.m_length describing a buffer that p_right no
// longer keeps alive, so its Data() could dangle once the new owner releases
// the buffer.
template<typename T>
Array<T>::Array(Array<T>&& p_right)
    : m_data(p_right.m_data),
      m_length(p_right.m_length),
      m_dataHolder(std::move(p_right.m_dataHolder))
{
    p_right.m_data = nullptr;
    p_right.m_length = 0;
}
// Copy constructor: the new array refers to the same buffer as p_right and
// shares its ownership handle. The elements themselves are not copied.
template<typename T>
Array<T>::Array(const Array<T>& p_right)
    : m_data(p_right.m_data), m_length(p_right.m_length), m_dataHolder(p_right.m_dataHolder)
{
}
// Move assignment: takes over p_right's data pointer, length and ownership
// handle, and leaves p_right empty. Any ownership handle previously held by
// this array is released by the shared_ptr assignment.
//
// Returns *this.
template<typename T>
Array<T>&
Array<T>::operator= (Array<T>&& p_right)
{
    // Self-move must stay a no-op; without the guard the reset below would
    // wipe this array's own contents.
    if (this != &p_right)
    {
        m_data = p_right.m_data;
        m_length = p_right.m_length;
        m_dataHolder = std::move(p_right.m_dataHolder);

        // Do not leave the source describing a buffer it no longer keeps alive.
        p_right.m_data = nullptr;
        p_right.m_length = 0;
    }

    return *this;
}
// Copy assignment: this array becomes another reference to p_right's buffer
// and shares its ownership handle. The elements themselves are not copied.
//
// Returns *this.
template<typename T>
Array<T>&
Array<T>::operator= (const Array<T>& p_right)
{
    m_dataHolder = p_right.m_dataHolder;
    m_length = p_right.m_length;
    m_data = p_right.m_data;

    return *this;
}
// Mutable element access. No bounds check is performed on p_index.
template<typename T>
T&
Array<T>::operator[] (std::size_t p_index)
{
    T* const element = m_data + p_index;
    return *element;
}
// Read-only element access. No bounds check is performed on p_index.
template<typename T>
const T&
Array<T>::operator[] (std::size_t p_index) const
{
    const T* const element = m_data + p_index;
    return *element;
}
// Destructor: nothing to do by hand; the ownership handle (m_dataHolder)
// releases the buffer when the last handle to it is destroyed.
template<typename T>
Array<T>::~Array() = default;
// Returns the raw pointer to the first element (nullptr for an empty array).
// The pointer is mutable even though the method is const.
template<typename T>
T*
Array<T>::Data() const
{
    T* const first = m_data;
    return first;
}
// Returns the number of elements recorded for this array.
template<typename T>
std::size_t
Array<T>::Length() const
{
    const std::size_t elementCount = m_length;
    return elementCount;
}
// Returns a copy of the ownership handle. The handle is empty when the array
// was never given ownership of a buffer.
template<typename T>
std::shared_ptr<T>
Array<T>::DataHolder() const
{
    std::shared_ptr<T> holder(m_dataHolder);
    return holder;
}
// Re-points the array at the buffer [p_array, p_array + p_length).
//
// When p_transferOwnership is true the array takes ownership of p_array and
// will release it with delete[]; the previously held handle is replaced.
// When it is false only the pointer and length change.
template<typename T>
void
Array<T>::Set(T* p_array, std::size_t p_length, bool p_transferOwnership)
{
    m_data = p_array;
    m_length = p_length;

    if (!p_transferOwnership)
    {
        // NOTE(review): the existing ownership handle is deliberately left as
        // is, so a previously owned buffer stays alive until Clear() or
        // destruction - confirm callers rely on this.
        return;
    }

    m_dataHolder.reset(p_array, std::default_delete<T[]>());
}
// Returns the array to the empty state: drops this array's ownership handle
// and resets the data pointer and length.
template<typename T>
void
Array<T>::Clear()
{
    m_dataHolder.reset();
    m_length = 0;
    m_data = nullptr;
}
// Allocates a new owning array of p_length elements with new T[p_length]
// (elements are default-initialized) and returns it. The buffer is released
// with delete[] when the last ownership handle goes away.
//
// A request for zero elements returns an empty array without allocating.
template<typename T>
Array<T>
Array<T>::Alloc(std::size_t p_length)
{
    Array<T> result;

    if (p_length != 0)
    {
        result.m_dataHolder.reset(new T[p_length], std::default_delete<T[]>());
        result.m_data = result.m_dataHolder.get();
        result.m_length = p_length;
    }

    return result;
}
// ByteArray is the byte-element instantiation of the generic Array template.
using ByteArray = Array<std::uint8_t>;
} // namespace SPTAG } // namespace SPTAG
#endif // _SPTAG_COMMONDATASTRUCTURE_H_ #endif // _SPTAG_COMMONDATASTRUCTURE_H_
...@@ -28,6 +28,8 @@ DefineErrorCode(FailedOpenFile, 0x0002) ...@@ -28,6 +28,8 @@ DefineErrorCode(FailedOpenFile, 0x0002)
DefineErrorCode(FailedCreateFile, 0x0003) DefineErrorCode(FailedCreateFile, 0x0003)
DefineErrorCode(ParamNotFound, 0x0010) DefineErrorCode(ParamNotFound, 0x0010)
DefineErrorCode(FailedParseValue, 0x0011) DefineErrorCode(FailedParseValue, 0x0011)
DefineErrorCode(MemoryOverFlow, 0x0012)
DefineErrorCode(LackOfInputs, 0x0013)
// 0x1000 ~ 0x1FFF Index Build Status // 0x1000 ~ 0x1FFF Index Build Status
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include "../Common/WorkSpacePool.h" #include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h" #include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/KDTree.h" #include "../Common/KDTree.h"
#include "inc/Helper/ConcurrentSet.h"
#include "inc/Helper/StringConvert.h" #include "inc/Helper/StringConvert.h"
#include "inc/Helper/SimpleIniReader.h" #include "inc/Helper/SimpleIniReader.h"
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG namespace SPTAG
{ {
...@@ -48,14 +48,16 @@ namespace SPTAG ...@@ -48,14 +48,16 @@ namespace SPTAG
std::string m_sKDTFilename; std::string m_sKDTFilename;
std::string m_sGraphFilename; std::string m_sGraphFilename;
std::string m_sDataPointsFilename; std::string m_sDataPointsFilename;
std::string m_sDeleteDataPointsFilename;
std::mutex m_dataLock; // protect data and graph std::mutex m_dataAddLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID; Helper::Concurrent::ConcurrentSet<SizeType> m_deletedID;
float m_fDeletePercentageForRefine;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool; std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads; int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod; DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length); float(*m_fComputeDistance)(const T* pX, const T* pY, DimensionType length);
int m_iMaxCheck; int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation; int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
...@@ -63,20 +65,21 @@ namespace SPTAG ...@@ -63,20 +65,21 @@ namespace SPTAG
int m_iNumberOfOtherDynamicPivots; int m_iNumberOfOtherDynamicPivots;
public: public:
Index() Index()
{ {
#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ #define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \ VarName = DefaultValue; \
#include "inc/Core/KDT/ParameterDefinitionList.h" #include "inc/Core/KDT/ParameterDefinitionList.h"
#undef DefineKDTParameter #undef DefineKDTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod); m_pSamples.SetName("Vector");
} m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {} ~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); } inline SizeType GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); } inline DimensionType GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; } inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; } inline int GetNumThreads() const { return m_iNumberOfThreads; }
...@@ -85,25 +88,41 @@ namespace SPTAG ...@@ -85,25 +88,41 @@ namespace SPTAG
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); } inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; } inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; }
inline bool ContainSample(const SizeType idx) const { return !m_deletedID.contains(idx); }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension); inline bool NeedRefine() const { return m_deletedID.size() >= (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); }
std::shared_ptr<std::vector<std::uint64_t>> BufferSize() const
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen); {
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs); std::shared_ptr<std::vector<std::uint64_t>> buffersize(new std::vector<std::uint64_t>);
buffersize->push_back(m_pSamples.BufferSize());
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout); buffersize->push_back(m_pTrees.BufferSize());
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader); buffersize->push_back(m_pGraph.BufferSize());
buffersize->push_back(m_deletedID.bufferSize());
return std::move(buffersize);
}
ErrorCode SaveConfig(std::ostream& p_configout) const;
ErrorCode SaveIndexData(const std::string& p_folderPath);
ErrorCode SaveIndexData(const std::vector<std::ostream*>& p_indexStreams);
ErrorCode LoadConfig(Helper::IniReader& p_reader);
ErrorCode LoadIndexData(const std::string& p_folderPath);
ErrorCode LoadIndexDataFromMemory(const std::vector<ByteArray>& p_indexBlobs);
ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension);
ErrorCode SearchIndex(QueryResult &p_query) const; ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension); ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum); ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum);
ErrorCode DeleteIndex(const SizeType& p_id);
ErrorCode SetParameter(const char* p_param, const char* p_value); ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const; std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath); ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const; ErrorCode RefineIndex(const std::vector<std::ostream*>& p_indexStreams);
private:
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet<SizeType> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const; void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
}; };
} // namespace KDT } // namespace KDT
......
...@@ -7,16 +7,17 @@ ...@@ -7,16 +7,17 @@
DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath") DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath") DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath") DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineKDTParameter(m_sDeleteDataPointsFilename, std::string, std::string("deletes.bin"), "DeleteVectorFilePath")
DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber") DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber")
DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit") DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit")
DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "NumSamplesKDTSplitConsideration") DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "Samples")
DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber") DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize") DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit") DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit")
DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize") DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, DimensionType, 32L, "NeighborhoodSize")
DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale") DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale") DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations") DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
...@@ -26,6 +27,7 @@ DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckFor ...@@ -26,6 +27,7 @@ DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckFor
DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads") DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod") DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineKDTParameter(m_fDeletePercentageForRefine, float, 0.4F, "DeletePercentageForRefine")
DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck") DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation") DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots") DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
......
...@@ -19,23 +19,23 @@ public: ...@@ -19,23 +19,23 @@ public:
virtual ~MetadataSet(); virtual ~MetadataSet();
virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0; virtual ByteArray GetMetadata(SizeType p_vectorID) const = 0;
virtual SizeType Count() const = 0; virtual SizeType Count() const = 0;
virtual bool Available() const = 0; virtual bool Available() const = 0;
virtual void AddBatch(MetadataSet& data) = 0; virtual std::pair<std::uint64_t, std::uint64_t> BufferSize() const = 0;
virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0; virtual void AddBatch(MetadataSet& data) = 0;
virtual ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) = 0; virtual ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut) = 0;
virtual ErrorCode LoadMetadataFromMemory(void *pGraphMemFile) = 0; virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0;
virtual ErrorCode RefineMetadata(std::vector<int>& indices, const std::string& p_folderPath); virtual ErrorCode RefineMetadata(std::vector<SizeType>& indices, std::ostream& p_metaOut, std::ostream& p_metaIndexOut);
static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst); virtual ErrorCode RefineMetadata(std::vector<SizeType>& indices, const std::string& p_metaFile, const std::string& p_metaindexFile);
}; };
...@@ -46,19 +46,20 @@ public: ...@@ -46,19 +46,20 @@ public:
~FileMetadataSet(); ~FileMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const; ByteArray GetMetadata(SizeType p_vectorID) const;
SizeType Count() const; SizeType Count() const;
bool Available() const; bool Available() const;
std::pair<std::uint64_t, std::uint64_t> BufferSize() const;
void AddBatch(MetadataSet& data); void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len); ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private: private:
std::ifstream* m_fp = nullptr; std::ifstream* m_fp = nullptr;
...@@ -77,25 +78,24 @@ private: ...@@ -77,25 +78,24 @@ private:
class MemMetadataSet : public MetadataSet class MemMetadataSet : public MetadataSet
{ {
public: public:
MemMetadataSet() = default;
MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count); MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count);
~MemMetadataSet(); ~MemMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const; ByteArray GetMetadata(SizeType p_vectorID) const;
SizeType Count() const; SizeType Count() const;
bool Available() const; bool Available() const;
std::pair<std::uint64_t, std::uint64_t> BufferSize() const;
void AddBatch(MetadataSet& data); void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len); ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private: private:
std::vector<std::uint64_t> m_offsets; std::vector<std::uint64_t> m_offsets;
......
...@@ -4,24 +4,13 @@ ...@@ -4,24 +4,13 @@
#ifndef _SPTAG_SEARCHQUERY_H_ #ifndef _SPTAG_SEARCHQUERY_H_
#define _SPTAG_SEARCHQUERY_H_ #define _SPTAG_SEARCHQUERY_H_
#include "CommonDataStructure.h" #include "SearchResult.h"
#include <cstring> #include <cstring>
namespace SPTAG namespace SPTAG
{ {
struct BasicResult
{
int VID;
float Dist;
BasicResult() : VID(-1), Dist(MaxDist) {}
BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {}
};
// Space to save temporary answer, similar with TopKCache // Space to save temporary answer, similar with TopKCache
class QueryResult class QueryResult
{ {
...@@ -38,39 +27,26 @@ public: ...@@ -38,39 +27,26 @@ public:
QueryResult(const void* p_target, int p_resultNum, bool p_withMeta) QueryResult(const void* p_target, int p_resultNum, bool p_withMeta)
: m_target(nullptr),
m_resultNum(0),
m_withMeta(false)
{ {
Init(p_target, p_resultNum, p_withMeta); Init(p_target, p_resultNum, p_withMeta);
} }
QueryResult(const void* p_target, int p_resultNum, std::vector<BasicResult>& p_results) QueryResult(const void* p_target, int p_resultNum, bool p_withMeta, BasicResult* p_results)
: m_target(p_target), : m_target(p_target),
m_resultNum(p_resultNum), m_resultNum(p_resultNum),
m_withMeta(false) m_withMeta(p_withMeta)
{ {
p_results.resize(p_resultNum); m_results.Set(p_results, p_resultNum, false);
m_results.reset(p_results.data());
} }
QueryResult(const QueryResult& p_other) QueryResult(const QueryResult& p_other)
: m_target(p_other.m_target),
m_resultNum(p_other.m_resultNum),
m_withMeta(p_other.m_withMeta)
{ {
Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta);
if (m_resultNum > 0) if (m_resultNum > 0)
{ {
m_results.reset(new BasicResult[m_resultNum]); std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, m_results.Data());
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
if (m_withMeta)
{
m_metadatas.reset(new ByteArray[m_resultNum]);
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
} }
} }
...@@ -78,14 +54,9 @@ public: ...@@ -78,14 +54,9 @@ public:
QueryResult& operator=(const QueryResult& p_other) QueryResult& operator=(const QueryResult& p_other)
{ {
Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta); Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta);
if (m_resultNum > 0) if (m_resultNum > 0)
{ {
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum); std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, m_results.Data());
if (m_withMeta)
{
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
} }
return *this; return *this;
...@@ -100,18 +71,10 @@ public: ...@@ -100,18 +71,10 @@ public:
inline void Init(const void* p_target, int p_resultNum, bool p_withMeta) inline void Init(const void* p_target, int p_resultNum, bool p_withMeta)
{ {
m_target = p_target; m_target = p_target;
if (p_resultNum > m_resultNum)
{
m_results.reset(new BasicResult[p_resultNum]);
}
if (p_withMeta && (!m_withMeta || p_resultNum > m_resultNum))
{
m_metadatas.reset(new ByteArray[p_resultNum]);
}
m_resultNum = p_resultNum; m_resultNum = p_resultNum;
m_withMeta = p_withMeta; m_withMeta = p_withMeta;
m_results = Array<BasicResult>::Alloc(p_resultNum);
} }
...@@ -135,11 +98,11 @@ public: ...@@ -135,11 +98,11 @@ public:
inline BasicResult* GetResult(int i) const inline BasicResult* GetResult(int i) const
{ {
return i < m_resultNum ? m_results.get() + i : nullptr; return i < m_resultNum ? m_results.Data() + i : nullptr;
} }
inline void SetResult(int p_index, int p_VID, float p_dist) inline void SetResult(int p_index, SizeType p_VID, float p_dist)
{ {
if (p_index < m_resultNum) if (p_index < m_resultNum)
{ {
...@@ -151,7 +114,7 @@ public: ...@@ -151,7 +114,7 @@ public:
inline BasicResult* GetResults() const inline BasicResult* GetResults() const
{ {
return m_results.get(); return m_results.Data();
} }
...@@ -165,7 +128,7 @@ public: ...@@ -165,7 +128,7 @@ public:
{ {
if (p_index < m_resultNum && m_withMeta) if (p_index < m_resultNum && m_withMeta)
{ {
return m_metadatas[p_index]; return m_results[p_index].Meta;
} }
return ByteArray::c_empty; return ByteArray::c_empty;
...@@ -176,7 +139,7 @@ public: ...@@ -176,7 +139,7 @@ public:
{ {
if (p_index < m_resultNum && m_withMeta) if (p_index < m_resultNum && m_withMeta)
{ {
m_metadatas[p_index] = std::move(p_metadata); m_results[p_index].Meta = std::move(p_metadata);
} }
} }
...@@ -187,39 +150,32 @@ public: ...@@ -187,39 +150,32 @@ public:
{ {
m_results[i].VID = -1; m_results[i].VID = -1;
m_results[i].Dist = MaxDist; m_results[i].Dist = MaxDist;
} m_results[i].Meta.Clear();
if (m_withMeta)
{
for (int i = 0; i < m_resultNum; i++)
{
m_metadatas[i].Clear();
}
} }
} }
iterator begin() iterator begin()
{ {
return m_results.get(); return m_results.Data();
} }
iterator end() iterator end()
{ {
return m_results.get() + m_resultNum; return m_results.Data() + m_resultNum;
} }
const_iterator begin() const const_iterator begin() const
{ {
return m_results.get(); return m_results.Data();
} }
const_iterator end() const const_iterator end() const
{ {
return m_results.get() + m_resultNum; return m_results.Data() + m_resultNum;
} }
...@@ -230,9 +186,7 @@ protected: ...@@ -230,9 +186,7 @@ protected:
bool m_withMeta; bool m_withMeta;
std::unique_ptr<BasicResult[]> m_results; Array<BasicResult> m_results;
std::unique_ptr<ByteArray[]> m_metadatas;
}; };
} // namespace SPTAG } // namespace SPTAG
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SEARCHRESULT_H_
#define _SPTAG_SEARCHRESULT_H_
#include "CommonDataStructure.h"
namespace SPTAG
{
// One search hit: the id of the matched vector, its distance to the query,
// and the metadata bytes attached to that vector (if any).
struct BasicResult
{
    // Defaults describe an "empty" slot: id -1 at distance MaxDist.
    SizeType VID = -1;
    float Dist = MaxDist;
    ByteArray Meta;

    // Builds an empty slot using the member defaults above.
    BasicResult() = default;

    // Builds a hit without metadata.
    BasicResult(SizeType p_vid, float p_dist)
        : VID(p_vid), Dist(p_dist)
    {
    }

    // Builds a hit carrying a copy of the given metadata handle.
    BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta)
        : VID(p_vid), Dist(p_dist), Meta(p_meta)
    {
    }
};
} // namespace SPTAG
#endif // _SPTAG_SEARCHRESULT_H_
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include "MetadataSet.h" #include "MetadataSet.h"
#include "inc/Helper/SimpleIniReader.h" #include "inc/Helper/SimpleIniReader.h"
#include <unordered_map>
namespace SPTAG namespace SPTAG
{ {
...@@ -20,59 +22,58 @@ public: ...@@ -20,59 +22,58 @@ public:
virtual ~VectorIndex(); virtual ~VectorIndex();
virtual ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout) = 0; virtual ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension) = 0;
virtual ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader) = 0;
virtual ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen) = 0; virtual ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr) = 0;
virtual ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs) = 0; virtual ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum) = 0;
virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0; virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0;
virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum) = 0;
//virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0;
//virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0;
virtual float ComputeDistance(const void* pX, const void* pY) const = 0; virtual float ComputeDistance(const void* pX, const void* pY) const = 0;
virtual const void* GetSample(const int idx) const = 0; virtual const void* GetSample(const SizeType idx) const = 0;
virtual int GetFeatureDim() const = 0; virtual bool ContainSample(const SizeType idx) const = 0;
virtual int GetNumSamples() const = 0; virtual bool NeedRefine() const = 0;
virtual DimensionType GetFeatureDim() const = 0;
virtual SizeType GetNumSamples() const = 0;
virtual DistCalcMethod GetDistCalcMethod() const = 0; virtual DistCalcMethod GetDistCalcMethod() const = 0;
virtual IndexAlgoType GetIndexAlgoType() const = 0; virtual IndexAlgoType GetIndexAlgoType() const = 0;
virtual VectorValueType GetVectorValueType() const = 0; virtual VectorValueType GetVectorValueType() const = 0;
virtual int GetNumThreads() const = 0;
virtual std::string GetParameter(const char* p_param) const = 0; virtual std::string GetParameter(const char* p_param) const = 0;
virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0; virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
virtual std::shared_ptr<std::vector<std::uint64_t>> CalculateBufferSize() const;
virtual ErrorCode LoadIndex(const std::string& p_config, const std::vector<ByteArray>& p_indexBlobs);
virtual ErrorCode LoadIndex(const std::string& p_folderPath); virtual ErrorCode LoadIndex(const std::string& p_folderPath);
virtual ErrorCode SaveIndex(std::string& p_config, const std::vector<ByteArray>& p_indexBlobs);
virtual ErrorCode SaveIndex(const std::string& p_folderPath); virtual ErrorCode SaveIndex(const std::string& p_folderPath);
virtual ErrorCode BuildIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet); virtual ErrorCode BuildIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet, bool p_withMetaIndex = false);
virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector<BasicResult>& p_results) const;
virtual ErrorCode AddIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet); virtual ErrorCode AddIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
virtual ErrorCode DeleteIndex(ByteArray p_meta);
virtual const void* GetSample(ByteArray p_meta);
virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, bool p_withMeta, BasicResult* p_results) const;
virtual std::string GetParameter(const std::string& p_param) const; virtual std::string GetParameter(const std::string& p_param) const;
virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value); virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
virtual ByteArray GetMetadata(IndexType p_vectorID) const; virtual ByteArray GetMetadata(SizeType p_vectorID) const;
virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath); virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath);
virtual std::string GetIndexName() const virtual std::string GetIndexName() const
{ {
if (m_sIndexName == "") if (m_sIndexName == "") return Helper::Convert::ConvertToString(GetIndexAlgoType());
return Helper::Convert::ConvertToString(GetIndexAlgoType());
return m_sIndexName; return m_sIndexName;
} }
virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; } virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; }
...@@ -83,9 +84,42 @@ public: ...@@ -83,9 +84,42 @@ public:
static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr<VectorIndex>& p_vectorIndex); static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr<VectorIndex>& p_vectorIndex);
static ErrorCode LoadIndex(const std::string& p_config, const std::vector<ByteArray>& p_indexBlobs, std::shared_ptr<VectorIndex>& p_vectorIndex);
protected:
virtual std::shared_ptr<std::vector<std::uint64_t>> BufferSize() const = 0;
virtual ErrorCode SaveConfig(std::ostream& p_configout) const = 0;
virtual ErrorCode SaveIndexData(const std::string& p_folderPath) = 0;
virtual ErrorCode SaveIndexData(const std::vector<std::ostream*>& p_indexStreams) = 0;
virtual ErrorCode LoadConfig(Helper::IniReader& p_reader) = 0;
virtual ErrorCode LoadIndexData(const std::string& p_folderPath) = 0;
virtual ErrorCode LoadIndexDataFromMemory(const std::vector<ByteArray>& p_indexBlobs) = 0;
virtual ErrorCode DeleteIndex(const SizeType& p_id) = 0;
virtual ErrorCode RefineIndex(const std::string& p_folderPath) = 0;
virtual ErrorCode RefineIndex(const std::vector<std::ostream*>& p_indexStreams) = 0;
private:
void BuildMetaMapping();
ErrorCode LoadIndexConfig(Helper::IniReader& p_reader);
ErrorCode SaveIndexConfig(std::ostream& p_configOut);
protected: protected:
std::string m_sIndexName; std::string m_sIndexName;
std::string m_sMetadataFile = "metadata.bin";
std::string m_sMetadataIndexFile = "metadataIndex.bin";
std::shared_ptr<MetadataSet> m_pMetadata; std::shared_ptr<MetadataSet> m_pMetadata;
std::unique_ptr<std::unordered_map<std::string, SizeType>> m_pMetaToVec;
}; };
......
...@@ -18,11 +18,11 @@ public: ...@@ -18,11 +18,11 @@ public:
virtual VectorValueType GetValueType() const = 0; virtual VectorValueType GetValueType() const = 0;
virtual void* GetVector(IndexType p_vectorID) const = 0; virtual void* GetVector(SizeType p_vectorID) const = 0;
virtual void* GetData() const = 0; virtual void* GetData() const = 0;
virtual SizeType Dimension() const = 0; virtual DimensionType Dimension() const = 0;
virtual SizeType Count() const = 0; virtual SizeType Count() const = 0;
...@@ -37,18 +37,18 @@ class BasicVectorSet : public VectorSet ...@@ -37,18 +37,18 @@ class BasicVectorSet : public VectorSet
public: public:
BasicVectorSet(const ByteArray& p_bytesArray, BasicVectorSet(const ByteArray& p_bytesArray,
VectorValueType p_valueType, VectorValueType p_valueType,
SizeType p_dimension, DimensionType p_dimension,
SizeType p_vectorCount); SizeType p_vectorCount);
virtual ~BasicVectorSet(); virtual ~BasicVectorSet();
virtual VectorValueType GetValueType() const; virtual VectorValueType GetValueType() const;
virtual void* GetVector(IndexType p_vectorID) const; virtual void* GetVector(SizeType p_vectorID) const;
virtual void* GetData() const; virtual void* GetData() const;
virtual SizeType Dimension() const; virtual DimensionType Dimension() const;
virtual SizeType Count() const; virtual SizeType Count() const;
...@@ -61,7 +61,7 @@ private: ...@@ -61,7 +61,7 @@ private:
VectorValueType m_valueType; VectorValueType m_valueType;
SizeType m_dimension; DimensionType m_dimension;
SizeType m_vectorCount; SizeType m_vectorCount;
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_BUFFERSTREAM_H_
#define _SPTAG_HELPER_BUFFERSTREAM_H_
#include <streambuf>
#include <ostream>
#include <memory>
namespace SPTAG
{
namespace Helper
{
// Output stream buffer backed by a caller-supplied character array.
// The put area is set to [buffer, buffer + size); the memory is not owned
// or freed by this object.
struct streambuf : public std::basic_streambuf<char>
{
    // buffer: start of the writable region.
    // size:   number of characters available starting at buffer.
    streambuf(char* buffer, size_t size)
    {
        char* const first = buffer;
        char* const last = first + size;
        this->setp(first, last);
    }
};
// std::ostream that writes through a Helper::streambuf and can optionally
// take ownership of that buffer object.
class obufferstream : public std::ostream
{
public:
    // buf:               stream buffer the ostream writes into.
    // transferOwnership: when true, buf is deleted together with this stream;
    //                    when false, the caller keeps responsibility for it.
    obufferstream(streambuf* buf, bool transferOwnership)
        : std::ostream(buf)
    {
        if (!transferOwnership)
            return;
        m_bufHolder = std::shared_ptr<streambuf>(buf);
    }

private:
    // Holds buf only when ownership was transferred; otherwise stays empty.
    std::shared_ptr<streambuf> m_bufHolder;
};
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_BUFFERSTREAM_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_CONCURRENTSET_H_
#define _SPTAG_HELPER_CONCURRENTSET_H_
#include <shared_mutex>
#include <unordered_set>
namespace SPTAG
{
namespace Helper
{
namespace Concurrent
{
// Unordered set of T guarded by a shared_timed_mutex.
// size(), contains() and insert() take the lock themselves; bufferSize(),
// save() and load() do not, so callers that need them synchronized must
// hold the mutex returned by getLock().
template <class T>
class ConcurrentSet
{
public:
    ConcurrentSet();
    ~ConcurrentSet();

    // --- queries (take a shared lock internally) ---
    size_t size() const;
    bool contains(const T& key) const;

    // --- mutation (takes an exclusive lock internally) ---
    void insert(const T& key);

    // Exposes the internal mutex so callers can lock around save/load.
    std::shared_timed_mutex& getLock();

    // --- persistence: a SizeType count followed by the raw bytes of each element ---
    bool save(std::ostream& output);
    bool save(std::string filename);
    bool load(std::string filename);
    bool load(char* pmemoryFile);

    // Number of bytes save() writes for the current contents.
    std::uint64_t bufferSize() const;

private:
    std::unique_ptr<std::shared_timed_mutex> m_lock;
    std::unordered_set<T> m_data;
};
// Creates an empty set together with the mutex that guards it.
template<typename T>
ConcurrentSet<T>::ConcurrentSet()
    : m_lock(new std::shared_timed_mutex)
{
}
// Nothing to release by hand: the mutex and the set clean up themselves.
template<typename T>
ConcurrentSet<T>::~ConcurrentSet() = default;
// Returns the number of stored elements, read under a shared lock.
template<typename T>
size_t ConcurrentSet<T>::size() const
{
    std::shared_lock<std::shared_timed_mutex> readGuard(*m_lock);
    const size_t elementCount = m_data.size();
    return elementCount;
}
// Returns true when key is present, checked under a shared lock.
template<typename T>
bool ConcurrentSet<T>::contains(const T& key) const
{
    std::shared_lock<std::shared_timed_mutex> readGuard(*m_lock);
    return m_data.count(key) > 0;
}
// Adds key to the set under an exclusive lock; inserting an existing key
// leaves the set unchanged.
template<typename T>
void ConcurrentSet<T>::insert(const T& key)
{
    std::lock_guard<std::shared_timed_mutex> writeGuard(*m_lock);
    m_data.insert(key);
}
// Hands out the internal mutex so callers can lock around operations that
// do not lock on their own.
template<typename T>
std::shared_timed_mutex& ConcurrentSet<T>::getLock()
{
    std::shared_timed_mutex& guardMutex = *m_lock;
    return guardMutex;
}
// Size in bytes of the serialized form: one SizeType count header plus
// sizeof(T) bytes per stored element. Does not take the lock.
template<typename T>
std::uint64_t ConcurrentSet<T>::bufferSize() const
{
    const auto headerBytes = sizeof(SizeType);
    const auto payloadBytes = sizeof(T) * m_data.size();
    return headerBytes + payloadBytes;
}
// Writes the set to output as a SizeType element count followed by the raw
// bytes of every element. Does not take the internal lock.
//
// Returns true only when the stream is still healthy after all writes; a
// failed write (for example a fixed-size buffer running out of room) now
// yields false instead of being reported as success.
template<typename T>
bool ConcurrentSet<T>::save(std::ostream& output)
{
    SizeType count = (SizeType)m_data.size();
    output.write((char*)&count, sizeof(SizeType));
    for (auto iter = m_data.begin(); iter != m_data.end(); iter++)
        output.write((char*)&(*iter), sizeof(T));

    // Stream errors are sticky, so one check after the loop covers every write.
    if (output.fail()) return false;

    std::cout << "Save DeleteID (" << count << ") Finish!" << std::endl;
    return true;
}
// Writes the set to the binary file filename using the stream overload.
//
// Returns false when the file cannot be opened, when the stream overload
// reports failure, or when the stream is in a failed state after close()
// (which is where a failed final flush shows up).
template<typename T>
bool ConcurrentSet<T>::save(std::string filename)
{
    std::cout << "Save DeleteID To " << filename << std::endl;
    std::ofstream output(filename, std::ios::binary);
    if (!output.is_open()) return false;

    const bool written = save(output);
    output.close();
    return written && !output.fail();
}
// Reads elements from the binary file filename (a SizeType count followed
// by that many raw T values) and inserts them into the set. Existing
// elements are kept. Does not take the internal lock.
//
// Returns false when the file cannot be opened, when the count header is
// missing or negative, or when the file ends before all elements are read.
template<typename T>
bool ConcurrentSet<T>::load(std::string filename)
{
    std::cout << "Load DeleteID From " << filename << std::endl;
    std::ifstream input(filename, std::ios::binary);
    if (!input.is_open()) return false;

    // Zero-initialized so a failed header read can never leave a garbage loop bound.
    SizeType count = 0;
    if (!input.read((char*)&count, sizeof(SizeType)) || count < 0) return false;

    T ID;
    for (SizeType i = 0; i < count; i++)
    {
        // Stop at the first short read instead of re-inserting stale bytes.
        if (!input.read((char*)&ID, sizeof(T))) return false;
        m_data.insert(ID);
    }
    input.close();
    std::cout << "Load DeleteID (" << count << ") Finish!" << std::endl;
    return true;
}
// Reads elements from an in-memory image laid out as a SizeType count
// followed by that many raw T values, and inserts them into the set.
// Existing elements are kept. Does not take the internal lock.
//
// Returns false when pmemoryFile is null or the stored count is negative.
// NOTE(review): the image is read through casted pointers, so it is assumed
// to be suitably aligned for SizeType and T — confirm against callers.
template<typename T>
bool ConcurrentSet<T>::load(char* pmemoryFile)
{
    if (pmemoryFile == nullptr) return false;

    const SizeType count = *((SizeType*)pmemoryFile);
    // A negative count would produce an inverted iterator range below.
    if (count < 0) return false;

    const T* first = (const T*)(pmemoryFile + sizeof(SizeType));
    m_data.insert(first, first + count);
    std::cout << "Load DeleteID (" << count << ") Finish!" << std::endl;
    return true;
}
}
}
}
#endif // _SPTAG_HELPER_CONCURRENTSET_H_
\ No newline at end of file
...@@ -31,6 +31,8 @@ public: ...@@ -31,6 +31,8 @@ public:
ErrorCode LoadIniFile(const std::string& p_iniFilePath); ErrorCode LoadIniFile(const std::string& p_iniFilePath);
ErrorCode LoadIni(std::istream& p_input);
bool DoesSectionExist(const std::string& p_section) const; bool DoesSectionExist(const std::string& p_section) const;
bool DoesParameterExist(const std::string& p_section, const std::string& p_param) const; bool DoesParameterExist(const std::string& p_section, const std::string& p_param) const;
......
// Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. // Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADER_H_ #ifndef _SPTAG_HELPER_VECTORSETREADER_H_
#define _SPTAG_INDEXBUILDER_VECTORSETREADER_H_ #define _SPTAG_HELPER_VECTORSETREADER_H_
#include "inc/Core/Common.h" #include "inc/Core/Common.h"
#include "inc/Core/VectorSet.h" #include "inc/Core/VectorSet.h"
#include "inc/Core/MetadataSet.h" #include "inc/Core/MetadataSet.h"
#include "Options.h" #include "inc/Helper/ArgumentsParser.h"
#include <memory> #include <memory>
namespace SPTAG namespace SPTAG
{ {
namespace IndexBuilder namespace Helper
{ {
class ReaderOptions : public ArgumentsParser
{
public:
ReaderOptions(VectorValueType p_valueType, DimensionType p_dimension, std::string p_vectorDelimiter = "|", std::uint32_t p_threadNum = 32);
~ReaderOptions();
std::uint32_t m_threadNum;
DimensionType m_dimension;
std::string m_vectorDelimiter;
SPTAG::VectorValueType m_inputValueType;
};
class VectorSetReader class VectorSetReader
{ {
public: public:
VectorSetReader(std::shared_ptr<BuilderOptions> p_options); VectorSetReader(std::shared_ptr<ReaderOptions> p_options);
virtual ~VectorSetReader(); virtual ~VectorSetReader();
...@@ -29,15 +45,15 @@ public: ...@@ -29,15 +45,15 @@ public:
virtual std::shared_ptr<MetadataSet> GetMetadataSet() const = 0; virtual std::shared_ptr<MetadataSet> GetMetadataSet() const = 0;
static std::shared_ptr<VectorSetReader> CreateInstance(std::shared_ptr<BuilderOptions> p_options); static std::shared_ptr<VectorSetReader> CreateInstance(std::shared_ptr<ReaderOptions> p_options);
protected: protected:
std::shared_ptr<BuilderOptions> m_options; std::shared_ptr<ReaderOptions> m_options;
}; };
} // namespace IndexBuilder } // namespace Helper
} // namespace SPTAG } // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADER_H_ #endif // _SPTAG_HELPER_VECTORSETREADER_H_
// Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. // Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_ #ifndef _SPTAG_HELPER_VECTORSETREADERS_DEFAULTREADER_H_
#define _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_ #define _SPTAG_HELPER_VECTORSETREADERS_DEFAULTREADER_H_
#include "../VectorSetReader.h" #include "../VectorSetReader.h"
#include "inc/Helper/Concurrent.h" #include "inc/Helper/Concurrent.h"
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
namespace SPTAG namespace SPTAG
{ {
namespace IndexBuilder namespace Helper
{ {
class DefaultReader : public VectorSetReader class DefaultReader : public VectorSetReader
{ {
public: public:
DefaultReader(std::shared_ptr<BuilderOptions> p_options); DefaultReader(std::shared_ptr<ReaderOptions> p_options);
virtual ~DefaultReader(); virtual ~DefaultReader();
...@@ -44,7 +44,7 @@ private: ...@@ -44,7 +44,7 @@ private:
template<typename DataType> template<typename DataType>
bool TranslateVector(char* p_str, DataType* p_vector) bool TranslateVector(char* p_str, DataType* p_vector)
{ {
std::uint32_t eleCount = 0; DimensionType eleCount = 0;
char* next = p_str; char* next = p_str;
while ((*next) != '\0') while ((*next) != '\0')
{ {
...@@ -85,11 +85,11 @@ private: ...@@ -85,11 +85,11 @@ private:
std::size_t m_subTaskBlocksize; std::size_t m_subTaskBlocksize;
std::atomic<std::uint32_t> m_totalRecordCount; std::atomic<SizeType> m_totalRecordCount;
std::atomic<std::size_t> m_totalRecordVectorBytes; std::atomic<std::size_t> m_totalRecordVectorBytes;
std::vector<std::uint32_t> m_subTaskRecordCount; std::vector<SizeType> m_subTaskRecordCount;
std::string m_vectorOutput; std::string m_vectorOutput;
...@@ -102,7 +102,7 @@ private: ...@@ -102,7 +102,7 @@ private:
} // namespace IndexBuilder } // namespace Helper
} // namespace SPTAG } // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULT_H_ #endif // _SPTAG_HELPER_VECTORSETREADERS_DEFAULT_H_
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#define _SPTAG_INDEXBUILDER_OPTIONS_H_ #define _SPTAG_INDEXBUILDER_OPTIONS_H_
#include "inc/Core/Common.h" #include "inc/Core/Common.h"
#include "inc/Helper/ArgumentsParser.h" #include "inc/Helper/VectorSetReader.h"
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -16,21 +16,13 @@ namespace SPTAG ...@@ -16,21 +16,13 @@ namespace SPTAG
namespace IndexBuilder namespace IndexBuilder
{ {
class BuilderOptions : public Helper::ArgumentsParser class BuilderOptions : public Helper::ReaderOptions
{ {
public: public:
BuilderOptions(); BuilderOptions();
~BuilderOptions(); ~BuilderOptions();
std::uint32_t m_threadNum;
std::uint32_t m_dimension;
std::string m_vectorDelimiter;
SPTAG::VectorValueType m_inputValueType;
std::string m_inputFiles; std::string m_inputFiles;
std::string m_outputFolder; std::string m_outputFolder;
......
...@@ -7,6 +7,4 @@ ...@@ -7,6 +7,4 @@
<package id="boost_system-vc140" version="1.67.0.0" targetFramework="native" /> <package id="boost_system-vc140" version="1.67.0.0" targetFramework="native" />
<package id="boost_thread-vc140" version="1.67.0.0" targetFramework="native" /> <package id="boost_thread-vc140" version="1.67.0.0" targetFramework="native" />
<package id="boost_wserialization-vc140" version="1.67.0.0" targetFramework="native" /> <package id="boost_wserialization-vc140" version="1.67.0.0" targetFramework="native" />
<package id="tbb_oss" version="9.107.0.0" targetFramework="native" />
<package id="tbb_oss.redist" version="9.107.0.0" targetFramework="native" />
</packages> </packages>
\ No newline at end of file
...@@ -53,19 +53,19 @@ int main(int argc, char** argv) ...@@ -53,19 +53,19 @@ int main(int argc, char** argv)
for (const auto& indexRes : result.m_allIndexResults) for (const auto& indexRes : result.m_allIndexResults)
{ {
fprintf(stdout, "Index: %s\n", indexRes.m_indexName.c_str()); std::cout << "Index: " << indexRes.m_indexName << std::endl;
int idx = 0; int idx = 0;
for (const auto& res : indexRes.m_results) for (const auto& res : indexRes.m_results)
{ {
fprintf(stdout, "------------------\n"); std::cout << "------------------" << std::endl;
fprintf(stdout, "DocIndex: %d Distance: %f\n", res.VID, res.Dist); std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist;
if (indexRes.m_results.WithMeta()) if (indexRes.m_results.WithMeta())
{ {
const auto& metadata = indexRes.m_results.GetMetadata(idx); const auto& metadata = indexRes.m_results.GetMetadata(idx);
fprintf(stdout, " MetaData: %.*s\n", static_cast<int>(metadata.Length()), metadata.Data()); std::cout << " MetaData: " << std::string((char*)metadata.Data(), metadata.Length());
} }
std::cout << std::endl;
++idx; ++idx;
} }
} }
......
...@@ -13,22 +13,7 @@ namespace SPTAG ...@@ -13,22 +13,7 @@ namespace SPTAG
namespace BKT namespace BKT
{ {
template <typename T> template <typename T>
ErrorCode Index<T>::LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs) ErrorCode Index<T>::LoadConfig(Helper::IniReader& p_reader)
{
if (!m_pSamples.Load((char*)p_indexBlobs[0])) return ErrorCode::FailedParseValue;
if (!m_pTrees.LoadTrees((char*)p_indexBlobs[1])) return ErrorCode::FailedParseValue;
if (!m_pGraph.LoadGraphFromMemory((char*)p_indexBlobs[2])) return ErrorCode::FailedParseValue;
m_pMetadata = std::make_shared<MemMetadataSet>();
if (ErrorCode::Success != m_pMetadata->LoadMetadataFromMemory((char*)p_indexBlobs[3]))
return ErrorCode::FailedParseValue;
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads);
return ErrorCode::Success;
}
template <typename T>
ErrorCode Index<T>::LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader)
{ {
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ #define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
SetParameter(RepresentStr, \ SetParameter(RepresentStr, \
...@@ -38,34 +23,96 @@ namespace SPTAG ...@@ -38,34 +23,96 @@ namespace SPTAG
#include "inc/Core/BKT/ParameterDefinitionList.h" #include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter #undef DefineBKTParameter
return ErrorCode::Success;
}
template <typename T>
ErrorCode Index<T>::LoadIndexDataFromMemory(const std::vector<ByteArray>& p_indexBlobs)
{
if (p_indexBlobs.size() < 3) return ErrorCode::LackOfInputs;
if (!m_pSamples.Load((char*)p_indexBlobs[0].Data())) return ErrorCode::FailedParseValue;
if (!m_pTrees.LoadTrees((char*)p_indexBlobs[1].Data())) return ErrorCode::FailedParseValue;
if (!m_pGraph.LoadGraph((char*)p_indexBlobs[2].Data())) return ErrorCode::FailedParseValue;
if (p_indexBlobs.size() > 3 && !m_deletedID.load((char*)p_indexBlobs[3].Data())) return ErrorCode::FailedParseValue;
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads);
return ErrorCode::Success;
}
template <typename T>
ErrorCode Index<T>::LoadIndexData(const std::string& p_folderPath)
{
if (!m_pSamples.Load(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; if (!m_pSamples.Load(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
if (!m_pTrees.LoadTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail; if (!m_pTrees.LoadTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail;
if (!m_pGraph.LoadGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail; if (!m_pGraph.LoadGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail;
if (!m_deletedID.load(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail;
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads); m_workSpacePool->Init(m_iNumberOfThreads);
return ErrorCode::Success; return ErrorCode::Success;
} }
template <typename T>
ErrorCode Index<T>::SaveConfig(std::ostream& p_configOut) const
{
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
p_configOut << RepresentStr << "=" << GetParameter(RepresentStr) << std::endl;
#include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter
p_configOut << std::endl;
return ErrorCode::Success;
}
template<typename T>
ErrorCode
Index<T>::SaveIndexData(const std::string& p_folderPath)
{
std::lock_guard<std::mutex> lock(m_dataAddLock);
std::shared_lock<std::shared_timed_mutex> sharedlock(m_deletedID.getLock());
if (!m_pSamples.Save(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
if (!m_pTrees.SaveTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail;
if (!m_pGraph.SaveGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail;
if (!m_deletedID.save(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail;
return ErrorCode::Success;
}
template<typename T>
ErrorCode Index<T>::SaveIndexData(const std::vector<std::ostream*>& p_indexStreams)
{
if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs;
std::lock_guard<std::mutex> lock(m_dataAddLock);
std::shared_lock<std::shared_timed_mutex> sharedlock(m_deletedID.getLock());
if (!m_pSamples.Save(*p_indexStreams[0])) return ErrorCode::Fail;
if (!m_pTrees.SaveTrees(*p_indexStreams[1])) return ErrorCode::Fail;
if (!m_pGraph.SaveGraph(*p_indexStreams[2])) return ErrorCode::Fail;
if (!m_deletedID.save(*p_indexStreams[3])) return ErrorCode::Fail;
return ErrorCode::Success;
}
#pragma region K-NN search #pragma region K-NN search
#define Search(CheckDeleted1) \ #define Search(CheckDeleted1) \
m_pTrees.InitSearchTrees(this, p_query, p_space); \ m_pTrees.InitSearchTrees(this, p_query, p_space); \
const int checkPos = m_pGraph.m_iNeighborhoodSize - 1; \ const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; \
while (!p_space.m_SPTQueue.empty()) { \ while (!p_space.m_SPTQueue.empty()) { \
m_pTrees.SearchTrees(this, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \ m_pTrees.SearchTrees(this, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \
while (!p_space.m_NGQueue.empty()) { \ while (!p_space.m_NGQueue.empty()) { \
COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \
const int *node = m_pGraph[gnode.node]; \ const SizeType *node = m_pGraph[gnode.node]; \
_mm_prefetch((const char *)node, _MM_HINT_T0); \ _mm_prefetch((const char *)node, _MM_HINT_T0); \
CheckDeleted1 { \ CheckDeleted1 { \
if (p_query.AddPoint(gnode.node, gnode.distance)) { \ if (p_query.AddPoint(gnode.node, gnode.distance)) { \
p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ p_space.m_iNumOfContinuousNoBetterPropagation = 0; \
int checkNode = node[checkPos]; \ SizeType checkNode = node[checkPos]; \
if (checkNode < -1) { \ if (checkNode < -1) { \
const COMMON::BKTNode& tnode = m_pTrees[-2 - checkNode]; \ const COMMON::BKTNode& tnode = m_pTrees[-2 - checkNode]; \
for (int i = -tnode.childStart; i < tnode.childEnd; i++) { \ for (SizeType i = -tnode.childStart; i < tnode.childEnd; i++) { \
if (!p_query.AddPoint(m_pTrees[i].centerid, gnode.distance)) break; \ if (!p_query.AddPoint(m_pTrees[i].centerid, gnode.distance)) break; \
} \ } \
} \ } \
...@@ -77,11 +124,11 @@ namespace SPTAG ...@@ -77,11 +124,11 @@ namespace SPTAG
} \ } \
} \ } \
} \ } \
for (int i = 0; i <= checkPos; i++) { \ for (DimensionType i = 0; i <= checkPos; i++) { \
_mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \
} \ } \
for (int i = 0; i <= checkPos; i++) { \ for (DimensionType i = 0; i <= checkPos; i++) { \
int nn_index = node[i]; \ SizeType nn_index = node[i]; \
if (nn_index < 0) break; \ if (nn_index < 0) break; \
if (p_space.CheckAndSet(nn_index)) continue; \ if (p_space.CheckAndSet(nn_index)) continue; \
float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \
...@@ -96,9 +143,9 @@ namespace SPTAG ...@@ -96,9 +143,9 @@ namespace SPTAG
p_query.SortResult(); \ p_query.SortResult(); \
template <typename T> template <typename T>
void Index<T>::SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const void Index<T>::SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet<SizeType> &p_deleted) const
{ {
Search(if (p_deleted.find(gnode.node) == p_deleted.end())) Search(if (!p_deleted.contains(gnode.node)))
} }
template <typename T> template <typename T>
...@@ -125,7 +172,7 @@ namespace SPTAG ...@@ -125,7 +172,7 @@ namespace SPTAG
{ {
for (int i = 0; i < p_query.GetResultNum(); ++i) for (int i = 0; i < p_query.GetResultNum(); ++i)
{ {
int result = p_query.GetResult(i)->VID; SizeType result = p_query.GetResult(i)->VID;
p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result)); p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result));
} }
} }
...@@ -134,7 +181,7 @@ namespace SPTAG ...@@ -134,7 +181,7 @@ namespace SPTAG
#pragma endregion #pragma endregion
template <typename T> template <typename T>
ErrorCode Index<T>::BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) ErrorCode Index<T>::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension)
{ {
omp_set_num_threads(m_iNumberOfThreads); omp_set_num_threads(m_iNumberOfThreads);
...@@ -144,14 +191,14 @@ namespace SPTAG ...@@ -144,14 +191,14 @@ namespace SPTAG
{ {
int base = COMMON::Utils::GetBase<T>(); int base = COMMON::Utils::GetBase<T>();
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < GetNumSamples(); i++) { for (SizeType i = 0; i < GetNumSamples(); i++) {
COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base); COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base);
} }
} }
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads); m_workSpacePool->Init(m_iNumberOfThreads);
m_pTrees.BuildTrees<T>(this); m_pTrees.BuildTrees<T>(this);
m_pGraph.BuildGraph<T>(this, &(m_pTrees.GetSampleMap())); m_pGraph.BuildGraph<T>(this, &(m_pTrees.GetSampleMap()));
...@@ -159,31 +206,22 @@ namespace SPTAG ...@@ -159,31 +206,22 @@ namespace SPTAG
} }
template <typename T> template <typename T>
ErrorCode Index<T>::RefineIndex(const std::string& p_folderPath) ErrorCode Index<T>::RefineIndex(const std::vector<std::ostream*>& p_indexStreams)
{ {
std::string folderPath(p_folderPath); std::lock_guard<std::mutex> lock(m_dataAddLock);
if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) std::shared_lock<std::shared_timed_mutex> sharedlock(m_deletedID.getLock());
{
folderPath += FolderSep;
}
if (!direxists(folderPath.c_str()))
{
mkdir(folderPath.c_str());
}
std::lock_guard<std::mutex> lock(m_dataLock); SizeType newR = GetNumSamples();
int newR = GetNumSamples();
std::vector<int> indices; std::vector<SizeType> indices;
std::vector<int> reverseIndices(newR); std::vector<SizeType> reverseIndices(newR);
for (int i = 0; i < newR; i++) { for (SizeType i = 0; i < newR; i++) {
if (m_deletedID.find(i) == m_deletedID.end()) { if (!m_deletedID.contains(i)) {
indices.push_back(i); indices.push_back(i);
reverseIndices[i] = i; reverseIndices[i] = i;
} }
else { else {
while (m_deletedID.find(newR - 1) != m_deletedID.end() && newR > i) newR--; while (m_deletedID.contains(newR - 1) && newR > i) newR--;
if (newR == i) break; if (newR == i) break;
indices.push_back(newR - 1); indices.push_back(newR - 1);
reverseIndices[newR - 1] = i; reverseIndices[newR - 1] = i;
...@@ -193,33 +231,72 @@ namespace SPTAG ...@@ -193,33 +231,72 @@ namespace SPTAG
std::cout << "Refine... from " << GetNumSamples() << "->" << newR << std::endl; std::cout << "Refine... from " << GetNumSamples() << "->" << newR << std::endl;
if (false == m_pSamples.Refine(indices, folderPath + m_sDataPointsFilename)) return ErrorCode::FailedCreateFile; if (false == m_pSamples.Refine(indices, *p_indexStreams[0])) return ErrorCode::Fail;
if (nullptr != m_pMetadata && ErrorCode::Success != m_pMetadata->RefineMetadata(indices, folderPath)) return ErrorCode::FailedCreateFile; if (nullptr != m_pMetadata && (p_indexStreams.size() < 6 || ErrorCode::Success != m_pMetadata->RefineMetadata(indices, *p_indexStreams[4], *p_indexStreams[5]))) return ErrorCode::Fail;
COMMON::BKTree newTrees(m_pTrees); COMMON::BKTree newTrees(m_pTrees);
newTrees.BuildTrees<T>(this, &indices); newTrees.BuildTrees<T>(this, &indices);
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < newTrees.size(); i++) { for (SizeType i = 0; i < newTrees.size(); i++) {
newTrees[i].centerid = reverseIndices[newTrees[i].centerid]; newTrees[i].centerid = reverseIndices[newTrees[i].centerid];
} }
newTrees.SaveTrees(folderPath + m_sBKTFilename); newTrees.SaveTrees(*p_indexStreams[1]);
m_pGraph.RefineGraph<T>(this, indices, reverseIndices, *p_indexStreams[2], &(newTrees.GetSampleMap()));
m_pGraph.RefineGraph<T>(this, indices, reverseIndices, folderPath + m_sGraphFilename, Helper::Concurrent::ConcurrentSet<SizeType> newDeletedID;
&(newTrees.GetSampleMap())); newDeletedID.save(*p_indexStreams[3]);
return ErrorCode::Success; return ErrorCode::Success;
} }
template <typename T> template <typename T>
ErrorCode Index<T>::DeleteIndex(const void* p_vectors, int p_vectorNum) { ErrorCode Index<T>::RefineIndex(const std::string& p_folderPath)
{
    // Body of the folder-path overload of RefineIndex: opens one binary output file per index
    // component inside p_folderPath and forwards them to the stream-based RefineIndex.
    //
    // Returns ErrorCode::FailedCreateFile when any file cannot be opened, otherwise whatever
    // the stream-based overload returns. Every stream is closed and freed on all paths; the
    // previous early return on an open failure leaked all of them.

    // Make sure the folder path ends with a separator so file names can be appended directly.
    std::string folderPath(p_folderPath);
    if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
    {
        folderPath += FolderSep;
    }

    if (!direxists(folderPath.c_str()))
    {
        mkdir(folderPath.c_str());
    }

    // Order matters: the stream-based overload reads [0] samples, [1] trees, [2] graph,
    // [3] deleted ids and, when metadata is present, [4] metadata and [5] metadata index.
    std::vector<std::ostream*> streams;
    streams.push_back(new std::ofstream(folderPath + m_sDataPointsFilename, std::ios::binary));
    streams.push_back(new std::ofstream(folderPath + m_sBKTFilename, std::ios::binary));
    streams.push_back(new std::ofstream(folderPath + m_sGraphFilename, std::ios::binary));
    streams.push_back(new std::ofstream(folderPath + m_sDeleteDataPointsFilename, std::ios::binary));
    if (nullptr != m_pMetadata)
    {
        streams.push_back(new std::ofstream(folderPath + m_sMetadataFile, std::ios::binary));
        streams.push_back(new std::ofstream(folderPath + m_sMetadataIndexFile, std::ios::binary));
    }

    // Every element was created above as std::ofstream, so the downcast is valid.
    ErrorCode ret = ErrorCode::Success;
    for (size_t i = 0; i < streams.size(); i++)
    {
        if (!(static_cast<std::ofstream*>(streams[i])->is_open()))
        {
            ret = ErrorCode::FailedCreateFile;
            break;
        }
    }

    // Only refine when all output files are usable.
    if (ErrorCode::Success == ret)
    {
        ret = RefineIndex(streams);
    }

    // Single cleanup path shared by success and failure.
    for (size_t i = 0; i < streams.size(); i++)
    {
        static_cast<std::ofstream*>(streams[i])->close();
        delete streams[i];
    }
    return ret;
}
template <typename T>
ErrorCode Index<T>::DeleteIndex(const void* p_vectors, SizeType p_vectorNum) {
const T* ptr_v = (const T*)p_vectors; const T* ptr_v = (const T*)p_vectors;
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < p_vectorNum; i++) { for (SizeType i = 0; i < p_vectorNum; i++) {
COMMON::QueryResultSet<T> query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); COMMON::QueryResultSet<T> query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF);
SearchIndex(query); SearchIndex(query);
for (int i = 0; i < m_pGraph.m_iCEF; i++) { for (int i = 0; i < m_pGraph.m_iCEF; i++) {
if (query.GetResult(i)->Dist < 1e-6) { if (query.GetResult(i)->Dist < 1e-6) {
std::lock_guard<std::mutex> lock(m_dataLock);
m_deletedID.insert(query.GetResult(i)->VID); m_deletedID.insert(query.GetResult(i)->VID);
} }
} }
...@@ -228,40 +305,43 @@ namespace SPTAG ...@@ -228,40 +305,43 @@ namespace SPTAG
} }
template <typename T> template <typename T>
ErrorCode Index<T>::AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) ErrorCode Index<T>::DeleteIndex(const SizeType& p_id) {
m_deletedID.insert(p_id);
return ErrorCode::Success;
}
template <typename T>
ErrorCode Index<T>::AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start)
{ {
int begin, end; SizeType begin, end;
{ {
std::lock_guard<std::mutex> lock(m_dataLock); std::lock_guard<std::mutex> lock(m_dataAddLock);
if (GetNumSamples() == 0)
return BuildIndex(p_vectors, p_vectorNum, p_dimension);
if (p_dimension != GetFeatureDim())
return ErrorCode::FailedParseValue;
begin = GetNumSamples(); begin = GetNumSamples();
end = GetNumSamples() + p_vectorNum; end = GetNumSamples() + p_vectorNum;
m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum); if (p_start != nullptr) *p_start = begin;
m_pGraph.AddBatch(p_vectorNum);
if (begin == 0) return BuildIndex(p_vectors, p_vectorNum, p_dimension);
if (m_pSamples.R() != end || m_pGraph.R() != end) { if (p_dimension != GetFeatureDim()) return ErrorCode::FailedParseValue;
if (m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum) != ErrorCode::Success || m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success) {
std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl; std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl;
m_pSamples.SetR(begin); m_pSamples.SetR(begin);
m_pGraph.SetR(begin); m_pGraph.SetR(begin);
return ErrorCode::Fail; return ErrorCode::MemoryOverFlow;
} }
if (DistCalcMethod::Cosine == m_iDistCalcMethod) if (DistCalcMethod::Cosine == m_iDistCalcMethod)
{ {
int base = COMMON::Utils::GetBase<T>(); int base = COMMON::Utils::GetBase<T>();
for (int i = begin; i < end; i++) { for (SizeType i = begin; i < end; i++) {
COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base); COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base);
} }
} }
} }
for (int node = begin; node < end; node++) for (SizeType node = begin; node < end; node++)
{ {
m_pGraph.RefineNode<T>(this, node, true); m_pGraph.RefineNode<T>(this, node, true);
} }
...@@ -269,47 +349,6 @@ namespace SPTAG ...@@ -269,47 +349,6 @@ namespace SPTAG
return ErrorCode::Success; return ErrorCode::Success;
} }
// Serialises the index into four in-memory blobs.
//
// Both output vectors are resized to 4 and filled as: [0] samples, [1] BKT trees,
// [2] graph, [3] metadata. The first component that fails stops the sequence and yields
// ErrorCode::Fail; otherwise ErrorCode::Success is returned.
//
// NOTE(review): m_pMetadata is dereferenced without a null check -- confirm that callers
// only use this on an index that has metadata.
template <typename T>
ErrorCode Index<T>::SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen)
{
    p_indexBlobs.resize(4);
    p_indexBlobsLen.resize(4);

    // && short-circuits, so later components are not attempted after a failure.
    const bool allSaved =
        m_pSamples.Save(&p_indexBlobs[0], p_indexBlobsLen[0]) &&
        m_pTrees.SaveTrees(&p_indexBlobs[1], p_indexBlobsLen[1]) &&
        m_pGraph.SaveGraphToMemory(&p_indexBlobs[2], p_indexBlobsLen[2]) &&
        ErrorCode::Success == m_pMetadata->SaveMetadataToMemory(&p_indexBlobs[3], p_indexBlobsLen[3]);

    return allSaved ? ErrorCode::Success : ErrorCode::Fail;
}
// Saves the index configuration to p_configout and the index data into p_folderPath.
//
// The data file names are reset to their fixed defaults first. The configuration is written
// as one "Name=Value" line per use of DefineBKTParameter in ParameterDefinitionList.h,
// followed by a blank line.
//
// When vectors have been marked deleted, the index is refined into the folder and the
// result of RefineIndex is returned, so a failed refine is no longer reported as Success.
// Otherwise samples, trees and graph are saved directly; the first failure yields
// ErrorCode::Fail.
template<typename T>
ErrorCode
Index<T>::SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout)
{
    m_sDataPointsFilename = "vectors.bin";
    m_sBKTFilename = "tree.bin";
    m_sGraphFilename = "graph.bin";

#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
    p_configout << RepresentStr << "=" << GetParameter(RepresentStr) << std::endl;
#include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter
    p_configout << std::endl;

    if (m_deletedID.size() > 0) {
        // Propagate the refine outcome instead of discarding it.
        return RefineIndex(p_folderPath);
    }

    if (!m_pSamples.Save(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
    if (!m_pTrees.SaveTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail;
    if (!m_pGraph.SaveGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail;
    return ErrorCode::Success;
}
template <typename T> template <typename T>
ErrorCode ErrorCode
Index<T>::SetParameter(const char* p_param, const char* p_value) Index<T>::SetParameter(const char* p_param, const char* p_value)
......
...@@ -7,7 +7,7 @@ using namespace SPTAG; ...@@ -7,7 +7,7 @@ using namespace SPTAG;
using namespace SPTAG::COMMON; using namespace SPTAG::COMMON;
WorkSpacePool::WorkSpacePool(int p_maxCheck, int p_vectorCount) WorkSpacePool::WorkSpacePool(int p_maxCheck, SizeType p_vectorCount)
: m_maxCheck(p_maxCheck), : m_maxCheck(p_maxCheck),
m_vectorCount(p_vectorCount) m_vectorCount(p_vectorCount)
{ {
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "inc/Core/CommonDataStructure.h"
using namespace SPTAG;
// Shared default-constructed instance: null data pointer and zero length.
const ByteArray ByteArray::c_empty;
// Creates an empty array: null data pointer, zero length, default (empty) holder.
ByteArray::ByteArray() : m_data(nullptr), m_length(0) {}
// Move constructor: takes over the data pointer, length and holder of p_right.
//
// p_right is left empty (null data, zero length, empty holder), the same state Clear()
// produces. Previously it kept its raw pointer and length after the holder had been moved
// out, so it could dangle once the new owner released the buffer.
ByteArray::ByteArray(ByteArray&& p_right)
    : m_data(p_right.m_data),
      m_length(p_right.m_length),
      m_dataHolder(std::move(p_right.m_dataHolder))
{
    p_right.m_data = nullptr;
    p_right.m_length = 0;
}
// Wraps an existing buffer of p_length bytes.
//
// When p_transferOnwership is true the holder is set up to release p_array with
// std::default_delete<std::uint8_t[]>; otherwise the holder stays empty and the buffer
// is only referenced.
ByteArray::ByteArray(std::uint8_t* p_array, std::size_t p_length, bool p_transferOnwership)
    : m_data(p_array), m_length(p_length)
{
    if (!p_transferOnwership)
    {
        return;
    }

    m_dataHolder.reset(m_data, std::default_delete<std::uint8_t[]>());
}
// Wraps a buffer of p_length bytes and stores the supplied shared holder alongside it.
ByteArray::ByteArray(std::uint8_t* p_array,
                     std::size_t p_length,
                     std::shared_ptr<std::uint8_t> p_dataHolder)
    : m_data(p_array), m_length(p_length), m_dataHolder(std::move(p_dataHolder))
{
}
// Copy constructor: copies the data pointer, length and holder; the bytes are not duplicated.
ByteArray::ByteArray(const ByteArray& p_right)
    : m_data(p_right.m_data), m_length(p_right.m_length), m_dataHolder(p_right.m_dataHolder)
{
}
// Copy assignment: copies the data pointer, length and holder of p_right; the bytes are
// not duplicated. Returns *this.
ByteArray& ByteArray::operator=(const ByteArray& p_right)
{
    m_dataHolder = p_right.m_dataHolder;
    m_length = p_right.m_length;
    m_data = p_right.m_data;
    return *this;
}
// Move assignment: takes over the data pointer, length and holder of p_right.
//
// p_right is left empty (null data, zero length, empty holder) so it no longer refers to
// a buffer it has given up. Self-assignment is a no-op. Returns *this.
ByteArray&
ByteArray::operator= (ByteArray&& p_right)
{
    // Guard against self-move: resetting the source below would otherwise wipe this object.
    if (this != &p_right)
    {
        m_data = p_right.m_data;
        m_length = p_right.m_length;
        m_dataHolder = std::move(p_right.m_dataHolder);

        p_right.m_data = nullptr;
        p_right.m_length = 0;
    }

    return *this;
}
// Nothing to release by hand; the members clean up after themselves.
ByteArray::~ByteArray() = default;
// Returns a ByteArray owning a freshly allocated buffer of p_length bytes.
// A length of zero yields a default-constructed (empty) array with no allocation.
// The buffer contents are not initialised.
ByteArray ByteArray::Alloc(std::size_t p_length)
{
    ByteArray result;

    if (p_length != 0)
    {
        // Array delete is required because the buffer comes from new[].
        result.m_dataHolder.reset(new std::uint8_t[p_length],
                                  std::default_delete<std::uint8_t[]>());
        result.m_data = result.m_dataHolder.get();
        result.m_length = p_length;
    }

    return result;
}
// Returns the stored data pointer (null for an empty array).
std::uint8_t* ByteArray::Data() const
{
    return m_data;
}
// Returns the stored length.
std::size_t ByteArray::Length() const
{
    return m_length;
}
// Points this array at p_array with length p_length.
// The holder is left untouched, so ownership of the new buffer is not taken here.
void ByteArray::SetData(std::uint8_t* p_array, std::size_t p_length)
{
    m_length = p_length;
    m_data = p_array;
}
// Returns a copy of the shared holder (empty when no holder is set).
std::shared_ptr<std::uint8_t> ByteArray::DataHolder() const
{
    return m_dataHolder;
}
// Resets the array to the empty state: drops this object's share of the holder and
// clears the data pointer and length.
void ByteArray::Clear()
{
    m_dataHolder.reset();
    m_length = 0;
    m_data = nullptr;
}
\ No newline at end of file
...@@ -19,7 +19,7 @@ VectorSet::~VectorSet() ...@@ -19,7 +19,7 @@ VectorSet::~VectorSet()
BasicVectorSet::BasicVectorSet(const ByteArray& p_bytesArray, BasicVectorSet::BasicVectorSet(const ByteArray& p_bytesArray,
VectorValueType p_valueType, VectorValueType p_valueType,
SizeType p_dimension, DimensionType p_dimension,
SizeType p_vectorCount) SizeType p_vectorCount)
: m_data(p_bytesArray), : m_data(p_bytesArray),
m_valueType(p_valueType), m_valueType(p_valueType),
...@@ -43,15 +43,14 @@ BasicVectorSet::GetValueType() const ...@@ -43,15 +43,14 @@ BasicVectorSet::GetValueType() const
void* void*
BasicVectorSet::GetVector(IndexType p_vectorID) const BasicVectorSet::GetVector(SizeType p_vectorID) const
{ {
if (p_vectorID < 0 || static_cast<SizeType>(p_vectorID) >= m_vectorCount) if (p_vectorID < 0 || p_vectorID >= m_vectorCount)
{ {
return nullptr; return nullptr;
} }
SizeType offset = static_cast<SizeType>(p_vectorID) * m_perVectorDataSize; return reinterpret_cast<void*>(m_data.Data() + ((size_t)p_vectorID) * m_perVectorDataSize);
return reinterpret_cast<void*>(m_data.Data() + offset);
} }
...@@ -61,7 +60,7 @@ BasicVectorSet::GetData() const ...@@ -61,7 +60,7 @@ BasicVectorSet::GetData() const
return reinterpret_cast<void*>(m_data.Data()); return reinterpret_cast<void*>(m_data.Data());
} }
SizeType DimensionType
BasicVectorSet::Dimension() const BasicVectorSet::Dimension() const
{ {
return m_dimension; return m_dimension;
...@@ -88,8 +87,8 @@ BasicVectorSet::Save(const std::string& p_vectorFile) const ...@@ -88,8 +87,8 @@ BasicVectorSet::Save(const std::string& p_vectorFile) const
FILE * fp = fopen(p_vectorFile.c_str(), "wb"); FILE * fp = fopen(p_vectorFile.c_str(), "wb");
if (fp == NULL) return ErrorCode::FailedOpenFile; if (fp == NULL) return ErrorCode::FailedOpenFile;
fwrite(&m_vectorCount, sizeof(int), 1, fp); fwrite(&m_vectorCount, sizeof(SizeType), 1, fp);
fwrite(&m_dimension, sizeof(int), 1, fp); fwrite(&m_dimension, sizeof(DimensionType), 1, fp);
fwrite((const void*)(m_data.Data()), m_data.Length(), 1, fp); fwrite((const void*)(m_data.Data()), m_data.Length(), 1, fp);
fclose(fp); fclose(fp);
......
...@@ -25,15 +25,8 @@ IniReader::~IniReader() ...@@ -25,15 +25,8 @@ IniReader::~IniReader()
} }
ErrorCode ErrorCode IniReader::LoadIni(std::istream& p_input)
IniReader::LoadIniFile(const std::string& p_iniFilePath)
{ {
std::ifstream input(p_iniFilePath);
if (!input.is_open())
{
return ErrorCode::FailedOpenFile;
}
const std::size_t c_bufferSize = 1 << 16; const std::size_t c_bufferSize = 1 << 16;
std::unique_ptr<char[]> line(new char[c_bufferSize]); std::unique_ptr<char[]> line(new char[c_bufferSize]);
...@@ -51,9 +44,9 @@ IniReader::LoadIniFile(const std::string& p_iniFilePath) ...@@ -51,9 +44,9 @@ IniReader::LoadIniFile(const std::string& p_iniFilePath)
return std::isspace(p_ch) != 0; return std::isspace(p_ch) != 0;
}; };
while (!input.eof()) while (!p_input.eof())
{ {
if (!input.getline(line.get(), c_bufferSize)) if (!p_input.getline(line.get(), c_bufferSize))
{ {
break; break;
} }
...@@ -141,11 +134,21 @@ IniReader::LoadIniFile(const std::string& p_iniFilePath) ...@@ -141,11 +134,21 @@ IniReader::LoadIniFile(const std::string& p_iniFilePath)
} }
} }
} }
return ErrorCode::Success; return ErrorCode::Success;
} }
// Opens the file at p_iniFilePath and parses it with LoadIni.
// Returns ErrorCode::FailedOpenFile when the file cannot be opened, otherwise the result
// of LoadIni.
ErrorCode IniReader::LoadIniFile(const std::string& p_iniFilePath)
{
    std::ifstream iniFile(p_iniFilePath);
    if (iniFile.is_open())
    {
        const ErrorCode parseResult = LoadIni(iniFile);
        iniFile.close();
        return parseResult;
    }

    return ErrorCode::FailedOpenFile;
}
bool bool
IniReader::DoesSectionExist(const std::string& p_section) const IniReader::DoesSectionExist(const std::string& p_section) const
{ {
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "inc/Helper/VectorSetReader.h"
#include "inc/Helper/VectorSetReaders/DefaultReader.h"
using namespace SPTAG;
using namespace SPTAG::Helper;
// Stores the supplied defaults and registers the command-line options of the reader:
// thread count (-t/--thread) and delimiter (--delimiter) as optional,
// dimension (-d/--dimension) and vector type (-v/--vectortype) as required.
ReaderOptions::ReaderOptions(VectorValueType p_valueType,
                             DimensionType p_dimension,
                             std::string p_vectorDelimiter,
                             std::uint32_t p_threadNum)
    : m_threadNum(p_threadNum),
      m_dimension(p_dimension),
      m_vectorDelimiter(p_vectorDelimiter),
      m_inputValueType(p_valueType)
{
    // Optional options.
    AddOptionalOption(m_threadNum, "-t", "--thread", "Thread Number.");
    AddOptionalOption(m_vectorDelimiter, "", "--delimiter", "Vector delimiter.");

    // Required options.
    AddRequiredOption(m_dimension, "-d", "--dimension", "Dimension of vector.");
    AddRequiredOption(m_inputValueType, "-v", "--vectortype", "Input vector data type. Default is float.");
}
// No explicit cleanup is performed.
ReaderOptions::~ReaderOptions() = default;
// Keeps a shared reference to the reader options.
VectorSetReader::VectorSetReader(std::shared_ptr<ReaderOptions> p_options) : m_options(p_options) {}
// No explicit cleanup is performed.
VectorSetReader::~VectorSetReader() = default;
// Factory: builds a DefaultReader from the given options and returns it through the
// VectorSetReader interface.
std::shared_ptr<VectorSetReader> VectorSetReader::CreateInstance(std::shared_ptr<ReaderOptions> p_options)
{
    std::shared_ptr<VectorSetReader> reader(new DefaultReader(std::move(p_options)));
    return reader;
}
...@@ -6,7 +6,7 @@ COPY AnnService ./AnnService/ ...@@ -6,7 +6,7 @@ COPY AnnService ./AnnService/
COPY Test ./Test/ COPY Test ./Test/
COPY Wrappers ./Wrappers/ COPY Wrappers ./Wrappers/
RUN apt-get update && apt-get -y install wget build-essential libtbb-dev \ RUN apt-get update && apt-get -y install wget build-essential \
# remove the following if you don't want to build the wrappers # remove the following if you don't want to build the wrappers
openjdk-8-jdk python3-pip swig openjdk-8-jdk python3-pip swig
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册