提交 e704c099 编写于 作者: Y Yu Kun

modify format


Former-commit-id: 846c7b569eb86267e0c549586f4923adafaef456
......@@ -8,3 +8,4 @@ base.info
output.info
output_new.info
server.info
thirdparty/knowhere/
......@@ -7,6 +7,7 @@ Please mark all change in change log and use the ticket from JIRA.
## Bug
## Improvement
- MS-327 - Clean code for milvus
## New Feature
......
......@@ -91,10 +91,6 @@ endif()
if(CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp")
if (CMAKE_LICENSE_CHECK STREQUAL "ON")
set(ENABLE_LICENSE "ON")
add_definitions("-DENABLE_LICENSE")
endif ()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp")
endif()
......
......@@ -24,7 +24,6 @@ set(db_scheduler_files
${scheduler_context_files}
${scheduler_task_files}
)
set(license_check_files
license/LicenseLibrary.cpp
license/LicenseCheck.cpp
......@@ -64,9 +63,6 @@ set(engine_files
${knowhere_files}
)
set(get_sys_info_files
license/GetSysInfo.cpp)
set(s3_client_files
storage/s3/S3ClientWrapper.cpp
storage/s3/S3ClientWrapper.h)
......@@ -163,29 +159,12 @@ if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
)
endif ()
if (ENABLE_LICENSE STREQUAL "ON")
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs")
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
set(license_libs
nvidia-ml
crypto
cudart
cublas
)
endif ()
cuda_add_library(milvus_engine STATIC ${engine_files})
target_link_libraries(milvus_engine ${engine_libs} ${third_party_libs})
add_library(metrics STATIC ${metrics_files})
if (ENABLE_LICENSE STREQUAL "ON")
add_library(license_check STATIC ${license_check_files})
target_link_libraries(license_check ${license_libs} ${third_party_libs})
endif ()
set(metrics_lib
easyloggingpp
yaml-cpp
......@@ -234,8 +213,13 @@ else()
)
endif()
if (ENABLE_LICENSE STREQUAL "ON")
add_executable(get_sys_info ${get_sys_info_files})
add_executable(license_generator ${license_generator_files})
target_link_libraries(get_sys_info ${license_libs} license_check ${third_party_libs})
target_link_libraries(license_generator ${license_libs} ${third_party_libs})
if(MILVUS_WITH_THRIFT STREQUAL "ON")
target_link_libraries(milvus_thrift_server ${server_libs} license_check ${knowhere_libs} ${third_party_libs})
else()
......@@ -250,17 +234,6 @@ else ()
endif()
if (ENABLE_LICENSE STREQUAL "ON")
add_executable(get_sys_info ${get_sys_info_files})
add_executable(license_generator ${license_generator_files})
target_link_libraries(get_sys_info ${license_libs} license_check ${third_party_libs})
target_link_libraries(license_generator ${license_libs} ${third_party_libs})
install(TARGETS get_sys_info DESTINATION bin)
install(TARGETS license_generator DESTINATION bin)
endif ()
if (MILVUS_WITH_THRIFT STREQUAL "ON")
install(TARGETS milvus_thrift_server DESTINATION bin)
else()
......
......@@ -6,7 +6,6 @@
#pragma once
#include "DB.h"
#include "MemManager.h"
#include "Types.h"
#include "utils/ThreadPool.h"
#include "MemManagerAbstract.h"
......
......@@ -6,7 +6,6 @@
#include "Factories.h"
#include "DBImpl.h"
#include "MemManager.h"
#include "NewMemManager.h"
#include "Exception.h"
......@@ -103,16 +102,6 @@ DB* DBFactory::Build(const Options& options) {
MemManagerAbstractPtr MemManagerFactory::Build(const std::shared_ptr<meta::Meta>& meta,
const Options& options) {
if (const char* env = getenv("MILVUS_USE_OLD_MEM_MANAGER")) {
std::string env_str = env;
std::transform(env_str.begin(), env_str.end(), env_str.begin(), ::toupper);
if (env_str == "ON") {
return std::make_shared<MemManager>(meta, options);
}
else {
return std::make_shared<NewMemManager>(meta, options);
}
}
return std::make_shared<NewMemManager>(meta, options);
}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#if 0
#include "FaissExecutionEngine.h"
#include "Log.h"
#include "utils/CommonUtil.h"
#include <faiss/AutoTune.h>
#include <faiss/MetaIndexes.h>
#include <faiss/IndexFlat.h>
#include <faiss/index_io.h>
#include <wrapper/Index.h>
#include <wrapper/IndexBuilder.h>
#include <cache/CpuCacheMgr.h>
#include "faiss/IndexIVF.h"
#include "metrics/Metrics.h"
namespace zilliz {
namespace milvus {
namespace engine {
namespace {
std::string GetMetricType() {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
}
}
std::string IndexStatsHelper::ToString(const std::string &prefix) const {
return "";
}
void IndexStatsHelper::Reset() const {
faiss::indexIVF_stats.reset();
}
std::string FaissIndexIVFStatsHelper::ToString(const std::string &prefix) const {
std::stringstream ss;
ss << prefix;
ss << identifier_ << ":";
ss << " NQ=" << faiss::indexIVF_stats.nq;
ss << " NL=" << faiss::indexIVF_stats.nlist;
ss << " ND=" << faiss::indexIVF_stats.ndis;
ss << " NH=" << faiss::indexIVF_stats.nheap_updates;
ss << " Q=" << faiss::indexIVF_stats.quantization_time;
ss << " S=" << faiss::indexIVF_stats.search_time;
return ss.str();
}
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension,
const std::string &location,
const std::string &build_index_type,
const std::string &raw_index_type)
: location_(location),
build_index_type_(build_index_type),
raw_index_type_(raw_index_type) {
std::string metric_type = GetMetricType();
faiss::MetricType faiss_metric_type = (metric_type == "L2") ? faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
pIndex_.reset(faiss::index_factory(dimension, raw_index_type.c_str(), faiss_metric_type));
}
FaissExecutionEngine::FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
const std::string &location,
const std::string &build_index_type,
const std::string &raw_index_type)
: pIndex_(index),
location_(location),
build_index_type_(build_index_type),
raw_index_type_(raw_index_type) {
}
Status FaissExecutionEngine::AddWithIds(long n, const float *xdata, const long *xids) {
pIndex_->add_with_ids(n, xdata, xids);
return Status::OK();
}
size_t FaissExecutionEngine::Count() const {
return (size_t) (pIndex_->ntotal);
}
size_t FaissExecutionEngine::Size() const {
return (size_t) (Count() * pIndex_->d) * sizeof(float);
}
size_t FaissExecutionEngine::Dimension() const {
return pIndex_->d;
}
size_t FaissExecutionEngine::PhysicalSize() const {
return server::CommonUtil::GetFileSize(location_);
}
Status FaissExecutionEngine::Serialize() {
write_index(pIndex_.get(), location_.c_str());
return Status::OK();
}
Status FaissExecutionEngine::Load(bool to_cache) {
auto index = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
bool already_in_cache = (index != nullptr);
auto start_time = METRICS_NOW_TIME;
if (!index) {
index = read_index(location_);
ENGINE_LOG_DEBUG << "Disk io from: " << location_;
}
pIndex_ = index->data();
if (!already_in_cache && to_cache) {
Cache();
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time);
double total_size = (pIndex_->d) * (pIndex_->ntotal) * 4;
server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(total_size);
// server::Metrics::GetInstance().FaissDiskLoadIOSpeedHistogramObserve(total_size/double(total_time));
server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(total_size / double(total_time));
}
return Status::OK();
}
Status FaissExecutionEngine::Merge(const std::string &location) {
if (location == location_) {
return Status::Error("Cannot Merge Self");
}
ENGINE_LOG_DEBUG << "Merge raw file: " << location << " to: " << location_;
auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
if (!to_merge) {
to_merge = read_index(location);
}
auto file_index = dynamic_cast<faiss::IndexIDMap *>(to_merge->data().get());
pIndex_->add_with_ids(file_index->ntotal, dynamic_cast<faiss::IndexFlat *>(file_index->index)->xb.data(),
file_index->id_map.data());
return Status::OK();
}
ExecutionEnginePtr
FaissExecutionEngine::BuildIndex(const std::string &location) {
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
auto opd = std::make_shared<Operand>();
opd->d = pIndex_->d;
opd->index_type = build_index_type_;
opd->metric_type = GetMetricType();
IndexBuilderPtr pBuilder = GetIndexBuilder(opd);
auto from_index = dynamic_cast<faiss::IndexIDMap *>(pIndex_.get());
auto index = pBuilder->build_all(from_index->ntotal,
dynamic_cast<faiss::IndexFlat *>(from_index->index)->xb.data(),
from_index->id_map.data());
ExecutionEnginePtr new_ee(new FaissExecutionEngine(index->data(), location, build_index_type_, raw_index_type_));
return new_ee;
}
Status FaissExecutionEngine::Search(long n,
const float *data,
long k,
float *distances,
long *labels) const {
auto start_time = METRICS_NOW_TIME;
std::shared_ptr<faiss::IndexIVF> ivf_index = std::dynamic_pointer_cast<faiss::IndexIVF>(pIndex_);
if (ivf_index) {
std::string stats_prefix = "K=" + std::to_string(k) + ":";
ENGINE_LOG_DEBUG << "Searching index type: " << build_index_type_ << " nProbe: " << nprobe_;
ivf_index->nprobe = nprobe_;
ivf_stats_helper_.Reset();
ivf_index->search(n, data, k, distances, labels);
ENGINE_LOG_INFO << ivf_stats_helper_.ToString(stats_prefix);
} else {
ENGINE_LOG_DEBUG << "Searching raw file";
pIndex_->search(n, data, k, distances, labels);
}
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
server::Metrics::GetInstance().QueryIndexTypePerSecondSet(build_index_type_, double(n) / double(total_time));
return Status::OK();
}
Status FaissExecutionEngine::Cache() {
auto index = std::make_shared<Index>(pIndex_);
cache::DataObjPtr data_obj = std::make_shared<cache::DataObj>(index, PhysicalSize());
zilliz::milvus::cache::CpuCacheMgr::GetInstance()->InsertItem(location_, data_obj);
return Status::OK();
}
Status FaissExecutionEngine::Init() {
if (build_index_type_ == BUILD_INDEX_TYPE_IVF ||
build_index_type_ == BUILD_INDEX_TYPE_IVFSQ8) {
using namespace zilliz::milvus::server;
ServerConfig &config = ServerConfig::GetInstance();
ConfigNode engine_config = config.GetConfig(CONFIG_ENGINE);
nprobe_ = engine_config.GetInt32Value(CONFIG_NPROBE, 1000);
nlist_ = engine_config.GetInt32Value(CONFIG_NLIST, 16384);
} else if (build_index_type_ == BUILD_INDEX_TYPE_IDMAP) { ;
} else {
return Status::Error("Wrong index type: ", build_index_type_);
}
return Status::OK();
}
} // namespace engine
} // namespace milvus
} // namespace zilliz
#endif
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#if 0
#include "ExecutionEngine.h"
#include "faiss/Index.h"
#include <memory>
#include <string>
namespace zilliz {
namespace milvus {
namespace engine {
const static std::string BUILD_INDEX_TYPE_IDMAP = "IDMap";
const static std::string BUILD_INDEX_TYPE_IVF = "IVF";
const static std::string BUILD_INDEX_TYPE_IVFSQ8 = "IVFSQ8";
class IndexStatsHelper {
public:
using Ptr = std::shared_ptr<IndexStatsHelper>;
virtual std::string ToString(const std::string &prefix = "") const;
virtual void Reset() const;
virtual ~IndexStatsHelper() {}
};
class FaissIndexIVFStatsHelper : public IndexStatsHelper {
public:
std::string ToString(const std::string &prefix = "") const override;
private:
const std::string identifier_ = BUILD_INDEX_TYPE_IVF;
};
class FaissExecutionEngine : public ExecutionEngine {
public:
FaissExecutionEngine(uint16_t dimension,
const std::string &location,
const std::string &build_index_type,
const std::string &raw_index_type);
FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
const std::string &location,
const std::string &build_index_type,
const std::string &raw_index_type);
Status AddWithIds(long n, const float *xdata, const long *xids) override;
size_t Count() const override;
size_t Size() const override;
size_t Dimension() const override;
size_t PhysicalSize() const override;
Status Serialize() override;
Status Load(bool to_cache) override;
Status Merge(const std::string &location) override;
Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const override;
ExecutionEnginePtr BuildIndex(const std::string &) override;
Status Cache() override;
Status Init() override;
protected:
FaissIndexIVFStatsHelper ivf_stats_helper_;
std::shared_ptr<faiss::Index> pIndex_;
std::string location_;
std::string build_index_type_;
std::string raw_index_type_;
size_t nprobe_ = 0;
size_t nlist_ = 0;
};
} // namespace engine
} // namespace milvus
} // namespace zilliz
#endif
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MemManager.h"
#include "Meta.h"
#include "MetaConsts.h"
#include "EngineFactory.h"
#include "metrics/Metrics.h"
#include "Log.h"
#include <iostream>
#include <sstream>
#include <thread>
#include <easylogging++.h>
namespace zilliz {
namespace milvus {
namespace engine {
MemVectors::MemVectors(const std::shared_ptr<meta::Meta> &meta_ptr,
const meta::TableFileSchema &schema, const Options &options)
: meta_(meta_ptr),
options_(options),
schema_(schema),
id_generator_(new SimpleIDGenerator()),
active_engine_(EngineFactory::Build(schema_.dimension_, schema_.location_, (EngineType) schema_.engine_type_)) {
}
Status MemVectors::Add(size_t n_, const float *vectors_, IDNumbers &vector_ids_) {
if (active_engine_ == nullptr) {
return Status::Error("index engine is null");
}
auto start_time = METRICS_NOW_TIME;
id_generator_->GetNextIDNumbers(n_, vector_ids_);
Status status = active_engine_->AddWithIds(n_, vectors_, vector_ids_.data());
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast<int>(n_),
static_cast<int>(schema_.dimension_),
total_time);
return status;
}
size_t MemVectors::RowCount() const {
if (active_engine_ == nullptr) {
return 0;
}
return active_engine_->Count();
}
size_t MemVectors::Size() const {
if (active_engine_ == nullptr) {
return 0;
}
return active_engine_->Size();
}
Status MemVectors::Serialize(std::string &table_id) {
if (active_engine_ == nullptr) {
return Status::Error("index engine is null");
}
table_id = schema_.table_id_;
auto size = Size();
auto start_time = METRICS_NOW_TIME;
active_engine_->Serialize();
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
schema_.size_ = size;
server::Metrics::GetInstance().DiskStoreIOSpeedGaugeSet(size / total_time);
schema_.file_type_ = (size >= options_.index_trigger_size) ?
meta::TableFileSchema::TO_INDEX : meta::TableFileSchema::RAW;
auto status = meta_->UpdateTableFile(schema_);
ENGINE_LOG_DEBUG << "New " << ((schema_.file_type_ == meta::TableFileSchema::RAW) ? "raw" : "to_index")
<< " file " << schema_.file_id_ << " of size " << active_engine_->Size() << " bytes";
if(options_.insert_cache_immediately_) {
active_engine_->Cache();
}
return status;
}
MemVectors::~MemVectors() {
if (id_generator_ != nullptr) {
delete id_generator_;
id_generator_ = nullptr;
}
}
/*
* MemManager
*/
MemManager::MemVectorsPtr MemManager::GetMemByTable(
const std::string &table_id) {
auto memIt = mem_id_map_.find(table_id);
if (memIt != mem_id_map_.end()) {
return memIt->second;
}
meta::TableFileSchema table_file;
table_file.table_id_ = table_id;
auto status = meta_->CreateTableFile(table_file);
if (!status.ok()) {
return nullptr;
}
mem_id_map_[table_id] = MemVectorsPtr(new MemVectors(meta_, table_file, options_));
return mem_id_map_[table_id];
}
Status MemManager::InsertVectors(const std::string &table_id_,
size_t n_,
const float *vectors_,
IDNumbers &vector_ids_) {
std::unique_lock<std::mutex> lock(mutex_);
return InsertVectorsNoLock(table_id_, n_, vectors_, vector_ids_);
}
Status MemManager::InsertVectorsNoLock(const std::string &table_id,
size_t n,
const float *vectors,
IDNumbers &vector_ids) {
MemVectorsPtr mem = GetMemByTable(table_id);
if (mem == nullptr) {
return Status::NotFound("Group " + table_id + " not found!");
}
//makesure each file size less than index_trigger_size
if (mem->Size() > options_.index_trigger_size) {
std::unique_lock<std::mutex> lock(serialization_mtx_);
immu_mem_list_.push_back(mem);
mem_id_map_.erase(table_id);
return InsertVectorsNoLock(table_id, n, vectors, vector_ids);
} else {
return mem->Add(n, vectors, vector_ids);
}
}
Status MemManager::ToImmutable() {
std::unique_lock<std::mutex> lock(mutex_);
MemIdMap temp_map;
for (auto &kv: mem_id_map_) {
if (kv.second->RowCount() == 0) {
temp_map.insert(kv);
continue;//empty vector, no need to serialize
}
immu_mem_list_.push_back(kv.second);
}
mem_id_map_.swap(temp_map);
return Status::OK();
}
Status MemManager::Serialize(std::set<std::string> &table_ids) {
ToImmutable();
std::unique_lock<std::mutex> lock(serialization_mtx_);
std::string table_id;
table_ids.clear();
for (auto &mem : immu_mem_list_) {
mem->Serialize(table_id);
table_ids.insert(table_id);
}
immu_mem_list_.clear();
return Status::OK();
}
Status MemManager::EraseMemVector(const std::string &table_id) {
{//erase MemVector from rapid-insert cache
std::unique_lock<std::mutex> lock(mutex_);
mem_id_map_.erase(table_id);
}
{//erase MemVector from serialize cache
std::unique_lock<std::mutex> lock(serialization_mtx_);
MemList temp_list;
for (auto &mem : immu_mem_list_) {
if (mem->TableId() != table_id) {
temp_list.push_back(mem);
}
}
immu_mem_list_.swap(temp_list);
}
return Status::OK();
}
size_t MemManager::GetCurrentMutableMem() {
size_t totalMem = 0;
for (auto &kv : mem_id_map_) {
auto memVector = kv.second;
totalMem += memVector->Size();
}
return totalMem;
}
size_t MemManager::GetCurrentImmutableMem() {
size_t totalMem = 0;
for (auto &memVector : immu_mem_list_) {
totalMem += memVector->Size();
}
return totalMem;
}
size_t MemManager::GetCurrentMem() {
return GetCurrentMutableMem() + GetCurrentImmutableMem();
}
} // namespace engine
} // namespace milvus
} // namespace zilliz
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "ExecutionEngine.h"
#include "IDGenerator.h"
#include "Status.h"
#include "Meta.h"
#include "MemManagerAbstract.h"
#include <map>
#include <string>
#include <ctime>
#include <memory>
#include <mutex>
namespace zilliz {
namespace milvus {
namespace engine {
namespace meta {
class Meta;
}
class MemVectors {
public:
using MetaPtr = meta::Meta::Ptr;
using Ptr = std::shared_ptr<MemVectors>;
explicit MemVectors(const std::shared_ptr<meta::Meta> &,
const meta::TableFileSchema &, const Options &);
Status Add(size_t n_, const float *vectors_, IDNumbers &vector_ids_);
size_t RowCount() const;
size_t Size() const;
Status Serialize(std::string &table_id);
~MemVectors();
const std::string &Location() const { return schema_.location_; }
std::string TableId() const { return schema_.table_id_; }
private:
MemVectors() = delete;
MemVectors(const MemVectors &) = delete;
MemVectors &operator=(const MemVectors &) = delete;
MetaPtr meta_;
Options options_;
meta::TableFileSchema schema_;
IDGenerator *id_generator_;
ExecutionEnginePtr active_engine_;
}; // MemVectors
class MemManager : public MemManagerAbstract {
public:
using MetaPtr = meta::Meta::Ptr;
using MemVectorsPtr = typename MemVectors::Ptr;
using Ptr = std::shared_ptr<MemManager>;
MemManager(const std::shared_ptr<meta::Meta> &meta, const Options &options)
: meta_(meta), options_(options) {}
Status InsertVectors(const std::string &table_id,
size_t n, const float *vectors, IDNumbers &vector_ids) override;
Status Serialize(std::set<std::string> &table_ids) override;
Status EraseMemVector(const std::string &table_id) override;
size_t GetCurrentMutableMem() override;
size_t GetCurrentImmutableMem() override;
size_t GetCurrentMem() override;
private:
MemVectorsPtr GetMemByTable(const std::string &table_id);
Status InsertVectorsNoLock(const std::string &table_id,
size_t n, const float *vectors, IDNumbers &vector_ids);
Status ToImmutable();
using MemIdMap = std::map<std::string, MemVectorsPtr>;
using MemList = std::vector<MemVectorsPtr>;
MemIdMap mem_id_map_;
MemList immu_mem_list_;
MetaPtr meta_;
Options options_;
std::mutex mutex_;
std::mutex serialization_mtx_;
}; // MemManager
} // namespace engine
} // namespace milvus
} // namespace zilliz
///*******************************************************************************
// * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// * Unauthorized copying of this file, via any medium is strictly prohibited.
// * Proprietary and confidential.
// ******************************************************************************/
//#pragma once
//
//#include <boost/serialization/access.hpp>
//#include <string>
//#include <map>
//
//
//class GPUInfoFile {
// public:
// GPUInfoFile() = default;
//
// GPUInfoFile(const int &device_count, const std::map<int, std::string> &uuid_encryption_map)
// : device_count_(device_count), uuid_encryption_map_(uuid_encryption_map) {}
//
// int get_device_count() {
// return device_count_;
// }
// std::map<int, std::string> &get_uuid_encryption_map() {
// return uuid_encryption_map_;
// }
//
//
// public:
// friend class boost::serialization::access;
//
// template<typename Archive>
// void serialize(Archive &ar, const unsigned int version) {
// ar & device_count_;
// ar & uuid_encryption_map_;
// }
//
// public:
// int device_count_ = 0;
// std::map<int, std::string> uuid_encryption_map_;
//};
//
//class SerializedGPUInfoFile {
// public:
// ~SerializedGPUInfoFile() {
// if (gpu_info_file_ != nullptr) {
// delete (gpu_info_file_);
// gpu_info_file_ = nullptr;
// }
// }
//
// void
// set_gpu_info_file(GPUInfoFile *gpu_info_file) {
// gpu_info_file_ = gpu_info_file;
// }
//
// GPUInfoFile *get_gpu_info_file() {
// return gpu_info_file_;
// }
// private:
// friend class boost::serialization::access;
//
// template<typename Archive>
// void serialize(Archive &ar, const unsigned int version) {
// ar & gpu_info_file_;
// }
//
// private:
// GPUInfoFile *gpu_info_file_ = nullptr;
//};
//
//#include "utils/Log.h"
//#include "LicenseLibrary.h"
//#include "utils/Error.h"
//
//#include <iostream>
//#include <getopt.h>
//#include <memory.h>
//// Not provide path: current work path will be used and system.info.
//using namespace zilliz::milvus;
//
//void
//print_usage(const std::string &app_name) {
// printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
// printf(" Options:\n");
// printf(" -h --help Print this help\n");
// printf(" -s --sysinfo filename Generate system info file as given name\n");
// printf("\n");
//}
//
//int main(int argc, char *argv[]) {
// std::string app_name = argv[0];
// if (argc != 1 && argc != 3) {
// print_usage(app_name);
// return EXIT_FAILURE;
// }
//
// static struct option long_options[] = {{"system_info", required_argument, 0, 's'},
// {"help", no_argument, 0, 'h'},
// {NULL, 0, 0, 0}};
// int value = 0;
// int option_index = 0;
// std::string system_info_filename = "./system.info";
// while ((value = getopt_long(argc, argv, "s:h", long_options, &option_index)) != -1) {
// switch (value) {
// case 's': {
// char *system_info_filename_ptr = strdup(optarg);
// system_info_filename = system_info_filename_ptr;
// free(system_info_filename_ptr);
//// printf("Generate system info file: %s\n", system_info_filename.c_str());
// break;
// }
// case 'h':print_usage(app_name);
// return EXIT_SUCCESS;
// case '?':print_usage(app_name);
// return EXIT_FAILURE;
// default:print_usage(app_name);
// break;
// }
// }
//
// int device_count = 0;
// server::ServerError err = server::LicenseLibrary::GetDeviceCount(device_count);
// if (err != server::SERVER_SUCCESS) return -1;
//
// // 1. Get All GPU UUID
// std::vector<std::string> uuid_array;
// err = server::LicenseLibrary::GetUUID(device_count, uuid_array);
// if (err != server::SERVER_SUCCESS) return -1;
//
// // 2. Get UUID SHA256
// std::vector<std::string> uuid_sha256_array;
// err = server::LicenseLibrary::GetUUIDSHA256(device_count, uuid_array, uuid_sha256_array);
// if (err != server::SERVER_SUCCESS) return -1;
//
// // 3. Generate GPU ID map with GPU UUID
// std::map<int, std::string> uuid_encrption_map;
// for (int i = 0; i < device_count; ++i) {
// uuid_encrption_map[i] = uuid_sha256_array[i];
// }
//
//
// // 4. Generate GPU_info File
// err = server::LicenseLibrary::GPUinfoFileSerialization(system_info_filename,
// device_count,
// uuid_encrption_map);
// if (err != server::SERVER_SUCCESS) return -1;
//
// printf("Generate GPU_info File Success\n");
//
//
// return 0;
//}
\ No newline at end of file
//#include "LicenseCheck.h"
//#include <iostream>
//#include <thread>
//
//#include <boost/archive/binary_oarchive.hpp>
//#include <boost/archive/binary_iarchive.hpp>
////#include <boost/foreach.hpp>
////#include <boost/serialization/vector.hpp>
//#include <boost/filesystem/path.hpp>
//#include <boost/serialization/map.hpp>
//#include <boost/filesystem/operations.hpp>
//#include <boost/thread.hpp>
//#include <boost/date_time/posix_time/posix_time.hpp>
//
//
//namespace zilliz {
//namespace milvus {
//namespace server {
//
//LicenseCheck::LicenseCheck() {
//
//}
//
//LicenseCheck::~LicenseCheck() {
// StopCountingDown();
//}
//
//ServerError
//LicenseCheck::LegalityCheck(const std::string &license_file_path) {
//
// int device_count;
// LicenseLibrary::GetDeviceCount(device_count);
// std::vector<std::string> uuid_array;
// LicenseLibrary::GetUUID(device_count, uuid_array);
//
// std::vector<std::string> sha_array;
// LicenseLibrary::GetUUIDSHA256(device_count, uuid_array, sha_array);
//
// int output_device_count;
// std::map<int, std::string> uuid_encryption_map;
// time_t starting_time;
// time_t end_time;
// ServerError err = LicenseLibrary::LicenseFileDeserialization(license_file_path,
// output_device_count,
// uuid_encryption_map,
// starting_time,
// end_time);
// if(err !=SERVER_SUCCESS)
// {
// std::cout << "License check error: 01" << std::endl;
// return SERVER_UNEXPECTED_ERROR;
// }
// time_t system_time;
// LicenseLibrary::GetSystemTime(system_time);
//
// if (device_count != output_device_count) {
// std::cout << "License check error: 02" << std::endl;
// return SERVER_UNEXPECTED_ERROR;
// }
// for (int i = 0; i < device_count; ++i) {
// if (sha_array[i] != uuid_encryption_map[i]) {
// std::cout << "License check error: 03" << std::endl;
// return SERVER_UNEXPECTED_ERROR;
// }
// }
// if (system_time < starting_time || system_time > end_time) {
// std::cout << "License check error: 04" << std::endl;
// return SERVER_UNEXPECTED_ERROR;
// }
// std::cout << "Legality Check Success" << std::endl;
// return SERVER_SUCCESS;
//}
//
//// Part 2: Timing check license
//
//ServerError
//LicenseCheck::AlterFile(const std::string &license_file_path,
// const boost::system::error_code &ec,
// boost::asio::deadline_timer *pt) {
//
// ServerError err = LicenseCheck::LegalityCheck(license_file_path);
// if(err!=SERVER_SUCCESS) {
// std::cout << "license file check error" << std::endl;
// exit(1);
// }
//
// std::cout << "---runing---" << std::endl;
// pt->expires_at(pt->expires_at() + boost::posix_time::hours(1));
// pt->async_wait(boost::bind(LicenseCheck::AlterFile, license_file_path, boost::asio::placeholders::error, pt));
//
// return SERVER_SUCCESS;
//
//}
//
//ServerError
//LicenseCheck::StartCountingDown(const std::string &license_file_path) {
//
// if (!LicenseLibrary::IsFileExistent(license_file_path)) {
// std::cout << "license file not exist" << std::endl;
// exit(1);
// }
//
// //create a thread to run AlterFile
// if(counting_thread_ == nullptr) {
// counting_thread_ = std::make_shared<std::thread>([&]() {
// boost::asio::deadline_timer t(io_service_, boost::posix_time::hours(1));
// t.async_wait(boost::bind(LicenseCheck::AlterFile, license_file_path, boost::asio::placeholders::error, &t));
// io_service_.run();//this thread will block here
// });
// }
//
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseCheck::StopCountingDown() {
// if(!io_service_.stopped()) {
// io_service_.stop();
// }
//
// if(counting_thread_ != nullptr) {
// counting_thread_->join();
// counting_thread_ = nullptr;
// }
//
// return SERVER_SUCCESS;
//}
//
//}
//}
//}
\ No newline at end of file
//#pragma once
//
//#include "utils/Error.h"
//#include "LicenseLibrary.h"
//
//#include <boost/asio.hpp>
//
//#include <thread>
//#include <memory>
//
//namespace zilliz {
//namespace milvus {
//namespace server {
//
//class LicenseCheck {
//private:
// LicenseCheck();
// ~LicenseCheck();
//
//public:
// static LicenseCheck &
// GetInstance() {
// static LicenseCheck instance;
// return instance;
// };
//
// static ServerError
// LegalityCheck(const std::string &license_file_path);
//
// ServerError
// StartCountingDown(const std::string &license_file_path);
//
// ServerError
// StopCountingDown();
//
//private:
// static ServerError
// AlterFile(const std::string &license_file_path,
// const boost::system::error_code &ec,
// boost::asio::deadline_timer *pt);
//
//private:
// boost::asio::io_service io_service_;
// std::shared_ptr<std::thread> counting_thread_;
//
//};
//
//}
//}
//}
//
//
///*******************************************************************************
// * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// * Unauthorized copying of this file, via any medium is strictly prohibited.
// * Proprietary and confidential.
// ******************************************************************************/
//#pragma once
//
//
//#include <boost/serialization/access.hpp>
//#include <string>
//#include <map>
//
//
//class LicenseFile {
// public:
// LicenseFile() = default;
//
// LicenseFile(const int &device_count,
// const std::map<int, std::string> &uuid_encryption_map,
// const time_t &starting_time,
// const time_t &end_time)
// : device_count_(device_count),
// uuid_encryption_map_(uuid_encryption_map),
// starting_time_(starting_time),
// end_time_(end_time) {}
//
// int get_device_count() {
// return device_count_;
// }
// std::map<int, std::string> &get_uuid_encryption_map() {
// return uuid_encryption_map_;
// }
// time_t get_starting_time() {
// return starting_time_;
// }
// time_t get_end_time() {
// return end_time_;
// }
//
// public:
// friend class boost::serialization::access;
//
// template<typename Archive>
// void serialize(Archive &ar, const unsigned int version) {
// ar & device_count_;
// ar & uuid_encryption_map_;
// ar & starting_time_;
// ar & end_time_;
// }
//
// public:
// int device_count_ = 0;
// std::map<int, std::string> uuid_encryption_map_;
// time_t starting_time_ = 0;
// time_t end_time_ = 0;
//};
//
//class SerializedLicenseFile {
// public:
// ~SerializedLicenseFile() {
// if (license_file_ != nullptr) {
// delete (license_file_);
// license_file_ = nullptr;
// }
// }
//
// void
// set_license_file(LicenseFile *license_file) {
// license_file_ = license_file;
// }
//
// LicenseFile *get_license_file() {
// return license_file_;
// }
// private:
// friend class boost::serialization::access;
//
// template<typename Archive>
// void serialize(Archive &ar, const unsigned int version) {
// ar & license_file_;
// }
//
// private:
// LicenseFile *license_file_ = nullptr;
//};
//
//
//#include <iostream>
//#include <getopt.h>
//#include <memory.h>
//
//#include "utils/Log.h"
//#include "license/LicenseLibrary.h"
//#include "utils/Error.h"
//
//
//using namespace zilliz::milvus;
//// Not provide path: current work path will be used and system.info.
//
//void
//print_usage(const std::string &app_name) {
// printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
// printf(" Options:\n");
// printf(" -h --help Print this help\n");
// printf(" -s --sysinfo filename sysinfo file location\n");
// printf(" -l --license filename Generate license file as given name\n");
// printf(" -b --starting time Set start time (format: YYYY-MM-DD)\n");
// printf(" -e --end time Set end time (format: YYYY-MM-DD)\n");
// printf("\n");
//}
//
//int main(int argc, char *argv[]) {
// std::string app_name = argv[0];
//// if (argc != 1 && argc != 3) {
//// print_usage(app_name);
//// return EXIT_FAILURE;
//// }
// static struct option long_options[] = {{"system_info", required_argument, 0, 's'},
// {"license", optional_argument, 0, 'l'},
// {"help", no_argument, 0, 'h'},
// {"starting_time", required_argument, 0, 'b'},
// {"end_time", required_argument, 0, 'e'},
// {NULL, 0, 0, 0}};
// server::ServerError err;
// int value = 0;
// int option_index = 0;
// std::string system_info_filename = "./system.info";
// std::string license_filename = "./system.license";
// char *string_starting_time = NULL;
// char *string_end_time = NULL;
// time_t starting_time = 0;
// time_t end_time = 0;
// int flag_s = 1;
// int flag_b = 1;
// int flag_e = 1;
// while ((value = getopt_long(argc, argv, "hl:s:b:e:", long_options, NULL)) != -1) {
// switch (value) {
// case 's': {
// flag_s = 0;
// system_info_filename = (std::string) (optarg);
// break;
// }
// case 'b': {
// flag_b = 0;
// string_starting_time = optarg;
// break;
// }
// case 'e': {
// flag_e = 0;
// string_end_time = optarg;
// break;
// }
// case 'l': {
// license_filename = (std::string) (optarg);
// break;
// }
// case 'h':print_usage(app_name);
// return EXIT_SUCCESS;
// case '?':print_usage(app_name);
// return EXIT_FAILURE;
// default:print_usage(app_name);
// break;
// }
//
// }
// if (flag_s) {
// printf("Error: sysinfo file location must be entered\n");
// return 1;
// }
// if (flag_b) {
// printf("Error: start time must be entered\n");
// return 1;
// }
// if (flag_e) {
// printf("Error: end time must be entered\n");
// return 1;
// }
//
// err = server::LicenseLibrary::GetDateTime(string_starting_time, starting_time);
// if (err != server::SERVER_SUCCESS) return -1;
//
// err = server::LicenseLibrary::GetDateTime(string_end_time, end_time);
// if (err != server::SERVER_SUCCESS) return -1;
//
//
// int output_info_device_count = 0;
// std::map<int, std::string> output_info_uuid_encrption_map;
//
//
// err = server::LicenseLibrary::GPUinfoFileDeserialization(system_info_filename,
// output_info_device_count,
// output_info_uuid_encrption_map);
// if (err != server::SERVER_SUCCESS) return -1;
//
//
// err = server::LicenseLibrary::LicenseFileSerialization(license_filename,
// output_info_device_count,
// output_info_uuid_encrption_map,
// starting_time,
// end_time);
// if (err != server::SERVER_SUCCESS) return -1;
//
//
// printf("Generate License File Success\n");
//
// return 0;
//}
//////////////////////////////////////////////////////////////////////////////////
//// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
//// Unauthorized copying of this file, via any medium is strictly prohibited.
//// Proprietary and confidential.
//////////////////////////////////////////////////////////////////////////////////
//
//#include "LicenseLibrary.h"
//#include "utils/Log.h"
//#include <cuda_runtime.h>
//#include <nvml.h>
//#include <openssl/md5.h>
//#include <openssl/sha.h>
//
//#include <boost/archive/binary_oarchive.hpp>
//#include <boost/archive/binary_iarchive.hpp>
////#include <boost/foreach.hpp>
////#include <boost/serialization/vector.hpp>
//#include <boost/filesystem/path.hpp>
//#include <boost/serialization/map.hpp>
//#include <boost/filesystem/operations.hpp>
//
//
//namespace zilliz {
//namespace milvus {
//namespace server {
//
//constexpr int LicenseLibrary::sha256_length_;
//
//// Part 0: File check
//bool
//LicenseLibrary::IsFileExistent(const std::string &path) {
//
// boost::system::error_code error;
// auto file_status = boost::filesystem::status(path, error);
// if (error) {
// return false;
// }
//
// if (!boost::filesystem::exists(file_status)) {
// return false;
// }
//
// return !boost::filesystem::is_directory(file_status);
//}
//
//// Part 1: Get GPU Info
//ServerError
//LicenseLibrary::GetDeviceCount(int &device_count) {
// nvmlReturn_t result = nvmlInit();
// if (NVML_SUCCESS != result) {
// printf("Failed to initialize NVML: %s\n", nvmlErrorString(result));
// return SERVER_UNEXPECTED_ERROR;
// }
// cudaError_t error_id = cudaGetDeviceCount(&device_count);
// if (error_id != cudaSuccess) {
// printf("cudaGetDeviceCount returned %d\n-> %s\n", (int) error_id, cudaGetErrorString(error_id));
// printf("Result = FAIL\n");
// return SERVER_UNEXPECTED_ERROR;
// }
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetUUID(int device_count, std::vector<std::string> &uuid_array) {
// if (device_count == 0) {
// printf("There are no available device(s) that support CUDA\n");
// return SERVER_UNEXPECTED_ERROR;
// }
//
// for (int dev = 0; dev < device_count; ++dev) {
// nvmlDevice_t device;
// nvmlReturn_t result = nvmlDeviceGetHandleByIndex(dev, &device);
// if (NVML_SUCCESS != result) {
// printf("Failed to get handle for device %i: %s\n", dev, nvmlErrorString(result));
// return SERVER_UNEXPECTED_ERROR;
// }
//
// char uuid[80];
// unsigned int length = 80;
// nvmlReturn_t err = nvmlDeviceGetUUID(device, uuid, length);
// if (err != NVML_SUCCESS) {
// printf("nvmlDeviceGetUUID error: %d\n", err);
// return SERVER_UNEXPECTED_ERROR;
// }
//
// uuid_array.emplace_back(uuid);
// }
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetUUIDMD5(int device_count,
// std::vector<std::string> &uuid_array,
// std::vector<std::string> &md5_array) {
// MD5_CTX ctx;
// unsigned char outmd[16];
// char temp[2];
// std::string md5;
// for (int dev = 0; dev < device_count; ++dev) {
// md5.clear();
// memset(outmd, 0, sizeof(outmd));
// MD5_Init(&ctx);
// MD5_Update(&ctx, uuid_array[dev].c_str(), uuid_array[dev].size());
// MD5_Final(outmd, &ctx);
// for (int i = 0; i < 16; ++i) {
// std::snprintf(temp, 2, "%02X", outmd[i]);
// md5 += temp;
// }
// md5_array.push_back(md5);
// }
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetUUIDSHA256(const int &device_count,
// std::vector<std::string> &uuid_array,
// std::vector<std::string> &sha_array) {
// SHA256_CTX ctx;
// unsigned char outmd[sha256_length_];
// char temp[2];
// std::string sha;
// for (int dev = 0; dev < device_count; ++dev) {
// sha.clear();
// memset(outmd, 0, sizeof(outmd));
// SHA256_Init(&ctx);
// SHA256_Update(&ctx, uuid_array[dev].c_str(), uuid_array[dev].size());
// SHA256_Final(outmd, &ctx);
// for (int i = 0; i < sha256_length_; ++i) {
// std::snprintf(temp, 2, "%02X", outmd[i]);
// sha += temp;
// }
// sha_array.push_back(sha);
// }
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetSystemTime(time_t &system_time) {
// system_time = time(NULL);
// return SERVER_SUCCESS;
//}
//
//// Part 2: Handle License File
//ServerError
//LicenseLibrary::LicenseFileSerialization(const std::string &path,
// int device_count,
// const std::map<int, std::string> &uuid_encrption_map,
// time_t starting_time,
// time_t end_time) {
//
// std::ofstream file(path);
// boost::archive::binary_oarchive oa(file);
// oa.register_type<LicenseFile>();
//
// SerializedLicenseFile serialized_license_file;
//
// serialized_license_file.set_license_file(new LicenseFile(device_count,
// uuid_encrption_map,
// starting_time,
// end_time));
// oa << serialized_license_file;
//
// file.close();
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::LicenseFileDeserialization(const std::string &path,
// int &device_count,
// std::map<int, std::string> &uuid_encrption_map,
// time_t &starting_time,
// time_t &end_time) {
// if (!IsFileExistent(path)) return SERVER_LICENSE_FILE_NOT_EXIST;
// std::ifstream file(path);
// boost::archive::binary_iarchive ia(file);
// ia.register_type<LicenseFile>();
//
// SerializedLicenseFile serialized_license_file;
// ia >> serialized_license_file;
//
// device_count = serialized_license_file.get_license_file()->get_device_count();
// uuid_encrption_map = serialized_license_file.get_license_file()->get_uuid_encryption_map();
// starting_time = serialized_license_file.get_license_file()->get_starting_time();
// end_time = serialized_license_file.get_license_file()->get_end_time();
//
// file.close();
// return SERVER_SUCCESS;
//}
//
////ServerError
////LicenseLibrary::SecretFileSerialization(const std::string &path,
//// const time_t &update_time,
//// const off_t &file_size,
//// const time_t &starting_time,
//// const time_t &end_time,
//// const std::string &file_md5) {
//// std::ofstream file(path);
//// boost::archive::binary_oarchive oa(file);
//// oa.register_type<SecretFile>();
////
//// SerializedSecretFile serialized_secret_file;
////
//// serialized_secret_file.set_secret_file(new SecretFile(update_time, file_size, starting_time, end_time, file_md5));
//// oa << serialized_secret_file;
////
//// file.close();
//// return SERVER_SUCCESS;
////}
////
////ServerError
////LicenseLibrary::SecretFileDeserialization(const std::string &path,
//// time_t &update_time,
//// off_t &file_size,
//// time_t &starting_time,
//// time_t &end_time,
//// std::string &file_md5) {
//// if (!IsFileExistent(path)) return SERVER_LICENSE_FILE_NOT_EXIST;
////
//// std::ifstream file(path);
//// boost::archive::binary_iarchive ia(file);
//// ia.register_type<SecretFile>();
//// SerializedSecretFile serialized_secret_file;
////
//// ia >> serialized_secret_file;
//// update_time = serialized_secret_file.get_secret_file()->get_update_time();
//// file_size = serialized_secret_file.get_secret_file()->get_file_size();
//// starting_time = serialized_secret_file.get_secret_file()->get_starting_time();
//// end_time = serialized_secret_file.get_secret_file()->get_end_time();
//// file_md5 = serialized_secret_file.get_secret_file()->get_file_md5();
//// file.close();
//// return SERVER_SUCCESS;
////}
//
//
//
//// Part 3: File attribute: UpdateTime Time/ Size/ MD5
//ServerError
//LicenseLibrary::GetFileUpdateTimeAndSize(const std::string &path, time_t &update_time, off_t &file_size) {
//
// if (!IsFileExistent(path)) return SERVER_LICENSE_FILE_NOT_EXIST;
//
// struct stat buf;
// int err_no = stat(path.c_str(), &buf);
// if (err_no != 0) {
// std::cout << strerror(err_no) << std::endl;
// return SERVER_UNEXPECTED_ERROR;
// }
//
// update_time = buf.st_mtime;
// file_size = buf.st_size;
//
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetFileMD5(const std::string &path, std::string &filemd5) {
//
// if (!IsFileExistent(path)) return SERVER_LICENSE_FILE_NOT_EXIST;
//
// filemd5.clear();
//
// std::ifstream file(path.c_str(), std::ifstream::binary);
// if (!file) {
// return -1;
// }
//
// MD5_CTX md5Context;
// MD5_Init(&md5Context);
//
// char buf[1024 * 16];
// while (file.good()) {
// file.read(buf, sizeof(buf));
// MD5_Update(&md5Context, buf, file.gcount());
// }
//
// unsigned char result[MD5_DIGEST_LENGTH];
// MD5_Final(result, &md5Context);
//
// char hex[35];
// memset(hex, 0, sizeof(hex));
// for (int i = 0; i < MD5_DIGEST_LENGTH; ++i) {
// sprintf(hex + i * 2, "%02X", result[i]);
// }
// hex[32] = '\0';
// filemd5 = std::string(hex);
//
// return SERVER_SUCCESS;
//}
//// Part 4: GPU Info File Serialization/Deserialization
//ServerError
//LicenseLibrary::GPUinfoFileSerialization(const std::string &path,
// int device_count,
// const std::map<int, std::string> &uuid_encrption_map) {
// std::ofstream file(path);
// boost::archive::binary_oarchive oa(file);
// oa.register_type<GPUInfoFile>();
//
// SerializedGPUInfoFile serialized_gpu_info_file;
//
// serialized_gpu_info_file.set_gpu_info_file(new GPUInfoFile(device_count, uuid_encrption_map));
// oa << serialized_gpu_info_file;
//
// file.close();
// return SERVER_SUCCESS;
//}
//ServerError
//LicenseLibrary::GPUinfoFileDeserialization(const std::string &path,
// int &device_count,
// std::map<int, std::string> &uuid_encrption_map) {
// if (!IsFileExistent(path)) return SERVER_LICENSE_FILE_NOT_EXIST;
//
// std::ifstream file(path);
// boost::archive::binary_iarchive ia(file);
// ia.register_type<GPUInfoFile>();
//
// SerializedGPUInfoFile serialized_gpu_info_file;
// ia >> serialized_gpu_info_file;
//
// device_count = serialized_gpu_info_file.get_gpu_info_file()->get_device_count();
// uuid_encrption_map = serialized_gpu_info_file.get_gpu_info_file()->get_uuid_encryption_map();
//
// file.close();
// return SERVER_SUCCESS;
//}
//
//ServerError
//LicenseLibrary::GetDateTime(const char *cha, time_t &data_time) {
// tm tm_;
// int year, month, day;
// sscanf(cha, "%d-%d-%d", &year, &month, &day);
// tm_.tm_year = year - 1900;
// tm_.tm_mon = month - 1;
// tm_.tm_mday = day;
// tm_.tm_hour = 0;
// tm_.tm_min = 0;
// tm_.tm_sec = 0;
// tm_.tm_isdst = 0;
// data_time = mktime(&tm_);
// return SERVER_SUCCESS;
//
//}
//
//}
//}
//}
\ No newline at end of file
//#pragma once
//
//#include "LicenseFile.h"
//#include "GPUInfoFile.h"
//
//#include "utils/Error.h"
//
//#include <boost/asio.hpp>
//#include <boost/thread.hpp>
//#include <boost/date_time/posix_time/posix_time.hpp>
//
//#include <vector>
//#include <map>
//#include <time.h>
//
//
//namespace zilliz {
//namespace milvus {
//namespace server {
//
//class LicenseLibrary {
// public:
// // Part 0: File check
// static bool
// IsFileExistent(const std::string &path);
//
// // Part 1: Get GPU Info
// static ServerError
// GetDeviceCount(int &device_count);
//
// static ServerError
// GetUUID(int device_count, std::vector<std::string> &uuid_array);
//
// static ServerError
// GetUUIDMD5(int device_count, std::vector<std::string> &uuid_array, std::vector<std::string> &md5_array);
//
//
// static ServerError
// GetUUIDSHA256(const int &device_count,
// std::vector<std::string> &uuid_array,
// std::vector<std::string> &sha_array);
//
// static ServerError
// GetSystemTime(time_t &system_time);
//
// // Part 2: Handle License File
// static ServerError
// LicenseFileSerialization(const std::string &path,
// int device_count,
// const std::map<int, std::string> &uuid_encrption_map,
// time_t starting_time,
// time_t end_time);
//
// static ServerError
// LicenseFileDeserialization(const std::string &path,
// int &device_count,
// std::map<int, std::string> &uuid_encrption_map,
// time_t &starting_time,
// time_t &end_time);
//
//// static ServerError
//// SecretFileSerialization(const std::string &path,
//// const time_t &update_time,
//// const off_t &file_size,
//// const time_t &starting_time,
//// const time_t &end_time,
//// const std::string &file_md5);
////
//// static ServerError
//// SecretFileDeserialization(const std::string &path,
//// time_t &update_time,
//// off_t &file_size,
//// time_t &starting_time,
//// time_t &end_time,
//// std::string &file_md5);
//
// // Part 3: File attribute: UpdateTime Time/ Size/ MD5
// static ServerError
// GetFileUpdateTimeAndSize(const std::string &path, time_t &update_time, off_t &file_size);
//
// static ServerError
// GetFileMD5(const std::string &path, std::string &filemd5);
//
// // Part 4: GPU Info File Serialization/Deserialization
// static ServerError
// GPUinfoFileSerialization(const std::string &path,
// int device_count,
// const std::map<int, std::string> &uuid_encrption_map);
// static ServerError
// GPUinfoFileDeserialization(const std::string &path,
// int &device_count,
// std::map<int, std::string> &uuid_encrption_map);
//
// static ServerError
// GetDateTime(const char *cha, time_t &data_time);
//
//
// private:
// static constexpr int sha256_length_ = 32;
//};
//
//
//}
//}
//}
......@@ -36,8 +36,6 @@ ClientProxy::Connect(const ConnectParam &param) {
}
}
Status
ClientProxy::Connect(const std::string &uri) {
if (!UriCheck(uri)) {
......
......@@ -113,7 +113,8 @@ class Connection {
* @return Indicate if connect is successful
*/
virtual Status Connect(const ConnectParam &param) = 0;
virtual Status
Connect(const ConnectParam &param) = 0;
/**
* @brief Connect
......@@ -125,7 +126,8 @@ class Connection {
*
* @return Indicate if connect is successful
*/
virtual Status Connect(const std::string &uri) = 0;
virtual Status
Connect(const std::string &uri) = 0;
/**
* @brief connected
......@@ -134,7 +136,8 @@ class Connection {
*
* @return Indicate if connection status
*/
virtual Status Connected() const = 0;
virtual Status
Connected() const = 0;
/**
* @brief Disconnect
......@@ -143,7 +146,8 @@ class Connection {
*
* @return Indicate if disconnect is successful
*/
virtual Status Disconnect() = 0;
virtual Status
Disconnect() = 0;
/**
......@@ -155,7 +159,8 @@ class Connection {
*
* @return Indicate if table is created successfully
*/
virtual Status CreateTable(const TableSchema &param) = 0;
virtual Status
CreateTable(const TableSchema &param) = 0;
/**
......@@ -167,7 +172,8 @@ class Connection {
*
* @return Indicate if table is cexist
*/
virtual bool HasTable(const std::string &table_name) = 0;
virtual bool
HasTable(const std::string &table_name) = 0;
/**
......@@ -179,9 +185,11 @@ class Connection {
*
* @return Indicate if table is delete successfully.
*/
virtual Status DropTable(const std::string &table_name) = 0;
virtual Status
DropTable(const std::string &table_name) = 0;
virtual Status DeleteTable(const std::string &table_name) = 0;
virtual Status
DeleteTable(const std::string &table_name) = 0;
/**
......@@ -193,7 +201,8 @@ class Connection {
*
* @return Indicate if build index successfully.
*/
virtual Status BuildIndex(const std::string &table_name) = 0;
virtual Status
BuildIndex(const std::string &table_name) = 0;
/**
* @brief Add vector to table
......@@ -206,11 +215,13 @@ class Connection {
*
* @return Indicate if vector array are inserted successfully
*/
virtual Status InsertVector(const std::string &table_name,
virtual Status
InsertVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
virtual Status AddVector(const std::string &table_name,
virtual Status
AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
......@@ -228,7 +239,8 @@ class Connection {
*
* @return Indicate if query is successful.
*/
virtual Status SearchVector(const std::string &table_name,
virtual Status
SearchVector(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
......@@ -244,7 +256,8 @@ class Connection {
*
* @return Indicate if this operation is successful.
*/
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) = 0;
virtual Status
DescribeTable(const std::string &table_name, TableSchema &table_schema) = 0;
/**
* @brief Get table row count
......@@ -256,7 +269,8 @@ class Connection {
*
* @return Indicate if this operation is successful.
*/
virtual Status GetTableRowCount(const std::string &table_name, int64_t &row_count) = 0;
virtual Status
GetTableRowCount(const std::string &table_name, int64_t &row_count) = 0;
/**
* @brief Show all tables in database
......@@ -267,7 +281,8 @@ class Connection {
*
* @return Indicate if this operation is successful.
*/
virtual Status ShowTables(std::vector<std::string> &table_array) = 0;
virtual Status
ShowTables(std::vector<std::string> &table_array) = 0;
/**
* @brief Give the client version
......@@ -276,7 +291,8 @@ class Connection {
*
* @return Client version.
*/
virtual std::string ClientVersion() const = 0;
virtual std::string
ClientVersion() const = 0;
/**
* @brief Give the server version
......@@ -285,7 +301,8 @@ class Connection {
*
* @return Server version.
*/
virtual std::string ServerVersion() const = 0;
virtual std::string
ServerVersion() const = 0;
/**
* @brief Give the server status
......@@ -294,7 +311,8 @@ class Connection {
*
* @return Server status.
*/
virtual std::string ServerStatus() const = 0;
virtual std::string
ServerStatus() const = 0;
};
}
\ No newline at end of file
......@@ -4,7 +4,6 @@
* Proprietary and confidential.
******************************************************************************/
#include "ConnectionImpl.h"
#include "version.h"
namespace milvus {
......
......@@ -19,67 +19,67 @@ public:
ConnectionImpl();
// Implementations of the Connection interface
virtual
Status Connect(const ConnectParam &param) override;
virtual Status
Connect(const ConnectParam &param) override;
virtual
Status Connect(const std::string &uri) override;
virtual Status
Connect(const std::string &uri) override;
virtual
Status Connected() const override;
virtual Status
Connected() const override;
virtual
Status Disconnect() override;
virtual Status
Disconnect() override;
virtual
Status CreateTable(const TableSchema &param) override;
virtual Status
CreateTable(const TableSchema &param) override;
virtual
bool HasTable(const std::string &table_name) override;
virtual
Status DropTable(const std::string &table_name) override;
virtual Status
DropTable(const std::string &table_name) override;
virtual
Status DeleteTable(const std::string &table_name) override;
virtual Status
DeleteTable(const std::string &table_name) override;
virtual
Status BuildIndex(const std::string &table_name) override;
virtual Status
BuildIndex(const std::string &table_name) override;
virtual
Status InsertVector(const std::string &table_name,
virtual Status
InsertVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual
Status AddVector(const std::string &table_name,
virtual Status
AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual
Status SearchVector(const std::string &table_name,
virtual Status
SearchVector(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual
Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status
DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual
Status GetTableRowCount(const std::string &table_name, int64_t &row_count) override;
virtual Status
GetTableRowCount(const std::string &table_name, int64_t &row_count) override;
virtual
Status ShowTables(std::vector<std::string> &table_array) override;
virtual Status
ShowTables(std::vector<std::string> &table_array) override;
virtual
std::string ClientVersion() const override;
virtual std::string
ClientVersion() const override;
virtual
std::string ServerVersion() const override;
virtual std::string
ServerVersion() const override;
virtual
std::string ServerStatus() const override;
virtual std::string
ServerStatus() const override;
private:
std::shared_ptr<ClientProxy> client_proxy_;
......
......@@ -14,7 +14,6 @@
#include "utils/Log.h"
#include "utils/SignalUtil.h"
#include "utils/TimeRecorder.h"
#include "license/LicenseCheck.h"
#include "metrics/Metrics.h"
#include <fcntl.h>
......@@ -157,19 +156,6 @@ Server::Start() {
ServerConfig &config = ServerConfig::GetInstance();
ConfigNode server_config = config.GetConfig(CONFIG_SERVER);
#ifdef ENABLE_LICENSE
ConfigNode license_config = config.GetConfig(CONFIG_LICENSE);
std::string license_file_path = license_config.GetValue(CONFIG_LICENSE_PATH);
SERVER_LOG_INFO << "License path: " << license_file_path;
if(server::LicenseCheck::LegalityCheck(license_file_path) != SERVER_SUCCESS) {
SERVER_LOG_ERROR << "License check failed";
exit(1);
}
server::LicenseCheck::GetInstance().StartCountingDown(license_file_path);
#endif
// Handle Signal
signal(SIGINT, SignalUtil::HandleSignal);
signal(SIGHUP, SignalUtil::HandleSignal);
......@@ -221,9 +207,6 @@ Server::Stop() {
StopService();
#ifdef ENABLE_LICENSE
server::LicenseCheck::GetInstance().StopCountingDown();
#endif
std::cout << "Milvus server is closed!" << std::endl;
}
......
......@@ -21,6 +21,7 @@ public:
*
* This method is used to create table
*
* @param context, add context for every RPC
* @param request, used to provide table information to be created.
* @param response, used to get the status
*
......@@ -28,6 +29,7 @@ public:
*
* @param request
* @param response
* @param context
*/
::grpc::Status
CreateTable(::grpc::ServerContext* context,
......@@ -38,6 +40,7 @@ public:
*
* This method is used to test table existence.
*
* @param context, add context for every RPC
* @param request, table name is going to be tested.
* @param response, get the bool reply of hastable
*
......@@ -45,6 +48,7 @@ public:
*
* @param request
* @param response
* @param context
*/
::grpc::Status
HasTable(::grpc::ServerContext* context,
......@@ -55,6 +59,7 @@ public:
*
* This method is used to drop table.
*
* @param context, add context for every RPC
* @param request, table name is going to be deleted.
* @param response, get the status of droptable
*
......@@ -62,6 +67,7 @@ public:
*
* @param request
* @param response
* @param context
*/
::grpc::Status
DropTable(::grpc::ServerContext* context,
......@@ -72,6 +78,7 @@ public:
*
* This method is used to build index by table in sync.
*
* @param context, add context for every RPC
* @param request, table name is going to be built index.
* @param response, get the status of buildindex
*
......@@ -79,6 +86,7 @@ public:
*
* @param request
* @param response
* @param context
*/
::grpc::Status
BuildIndex(::grpc::ServerContext* context,
......@@ -90,11 +98,13 @@ public:
*
* This method is used to insert vector array to table.
*
* @param context, add context for every RPC
* @param request, table_name is inserted.
* @param response, vector array is inserted.
*
* @return status
*
* @param context
* @param request
* @param response
*/
......@@ -107,6 +117,7 @@ public:
*
* This method is used to query vector in table.
*
* @param context, add context for every RPC
* @param request:
* table_name, table_name is queried.
* query_record_array, all vector are going to be queried.
......@@ -116,6 +127,8 @@ public:
* @param writer, write query result array.
*
* @return status
*
* @param context
* @param request
* @param writer
*/
......@@ -128,6 +141,7 @@ public:
*
* This method is used to query vector in specified files.
*
* @param context, add context for every RPC
* @param request:
* file_id_array, specified files id array, queried.
* query_record_array, all vector are going to be queried.
......@@ -138,6 +152,7 @@ public:
*
* @return status
*
* @param context
* @param request
* @param writer
*/
......@@ -150,11 +165,15 @@ public:
*
* This method is used to get table schema.
*
* @param table_name, target table name.
* @param context, add context for every RPC
* @param request, target table name.
* @param response, table schema
*
* @return table schema
* @return status
*
* @param table_name
* @param context
* @param request
* @param response
*/
::grpc::Status
DescribeTable(::grpc::ServerContext* context,
......@@ -165,11 +184,15 @@ public:
*
* This method is used to get table row count.
*
* @param table_name, target table name.
* @param context, add context for every RPC
* @param request, target table name.
* @param response, table row count
*
* @return table row count
*
* @param table_name
* @param request
* @param response
* @param context
*/
::grpc::Status
GetTableRowCount(::grpc::ServerContext* context,
......@@ -180,8 +203,15 @@ public:
*
* This method is used to list all tables.
*
* @param context, add context for every RPC
* @param request, show table command, usually not use
* @param writer, write tables to client
*
* @return status
*
* @param context
* @param request
* @param writer
*/
::grpc::Status
ShowTables(::grpc::ServerContext* context,
......@@ -192,10 +222,15 @@ public:
*
*
* This method is used to give the server status.
* @param context, add context for every RPC
* @param request, give server command
* @param response, server status
*
* @return status
*
* @param cmd
* @param context
* @param request
* @param response
*/
::grpc::Status
Ping(::grpc::ServerContext* context,
......
......@@ -62,14 +62,16 @@ BaseTask::~BaseTask() {
WaitToFinish();
}
ServerError BaseTask::Execute() {
ServerError
BaseTask::Execute() {
error_code_ = OnExecute();
done_ = true;
finish_cond_.notify_all();
return error_code_;
}
ServerError BaseTask::SetError(ServerError error_code, const std::string& error_msg) {
ServerError
BaseTask::SetError(ServerError error_code, const std::string& error_msg) {
error_code_ = error_code;
error_msg_ = error_msg;
......@@ -77,7 +79,8 @@ ServerError BaseTask::SetError(ServerError error_code, const std::string& error_
return error_code_;
}
ServerError BaseTask::WaitToFinish() {
ServerError
BaseTask::WaitToFinish() {
std::unique_lock <std::mutex> lock(finish_mtx_);
finish_cond_.wait(lock, [this] { return done_; });
......@@ -94,7 +97,8 @@ RequestScheduler::~RequestScheduler() {
Stop();
}
void RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *grpc_status) {
void
RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *grpc_status) {
if(task_ptr == nullptr) {
return;
}
......@@ -112,7 +116,8 @@ void RequestScheduler::ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status *g
}
}
void RequestScheduler::Start() {
void
RequestScheduler::Start() {
if(!stopped_) {
return;
}
......@@ -120,7 +125,8 @@ void RequestScheduler::Start() {
stopped_ = false;
}
void RequestScheduler::Stop() {
void
RequestScheduler::Stop() {
if(stopped_) {
return;
}
......@@ -145,7 +151,8 @@ void RequestScheduler::Stop() {
SERVER_LOG_INFO << "Scheduler stopped";
}
ServerError RequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
ServerError
RequestScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
......@@ -188,7 +195,8 @@ namespace {
}
}
ServerError RequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
ServerError
RequestScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
std::lock_guard<std::mutex> lock(queue_mtx_);
std::string group_name = task_ptr->TaskGroup();
......
......@@ -23,20 +23,30 @@ protected:
virtual ~BaseTask();
public:
ServerError Execute();
ServerError WaitToFinish();
ServerError
Execute();
std::string TaskGroup() const { return task_group_; }
ServerError
WaitToFinish();
ServerError ErrorCode() const { return error_code_; }
std::string ErrorMsg() const { return error_msg_; }
std::string
TaskGroup() const { return task_group_; }
bool IsAsync() const { return async_; }
ServerError
ErrorCode() const { return error_code_; }
std::string
ErrorMsg() const { return error_msg_; }
bool
IsAsync() const { return async_; }
protected:
virtual ServerError OnExecute() = 0;
virtual ServerError
OnExecute() = 0;
ServerError SetError(ServerError error_code, const std::string& msg);
ServerError
SetError(ServerError error_code, const std::string& msg);
protected:
mutable std::mutex finish_mtx_;
......@@ -64,15 +74,18 @@ public:
void Start();
void Stop();
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
ServerError
ExecuteTask(const BaseTaskPtr& task_ptr);
static void ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status* grpc_status);
static void
ExecTask(BaseTaskPtr& task_ptr, ::milvus::grpc::Status* grpc_status);
protected:
RequestScheduler();
virtual ~RequestScheduler();
ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr);
ServerError
PutTaskToQueue(const BaseTaskPtr& task_ptr);
private:
mutable std::mutex queue_mtx_;
......
......@@ -105,11 +105,13 @@ CreateTableTask::CreateTableTask(const ::milvus::grpc::TableSchema& schema)
}
BaseTaskPtr CreateTableTask::Create(const ::milvus::grpc::TableSchema& schema) {
BaseTaskPtr
CreateTableTask::Create(const ::milvus::grpc::TableSchema& schema) {
return std::shared_ptr<BaseTask>(new CreateTableTask(schema));
}
ServerError CreateTableTask::OnExecute() {
ServerError
CreateTableTask::OnExecute() {
TimeRecorder rc("CreateTableTask");
try {
......@@ -159,11 +161,13 @@ DescribeTableTask::DescribeTableTask(const std::string &table_name, ::milvus::gr
schema_(schema) {
}
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, ::milvus::grpc::TableSchema& schema) {
BaseTaskPtr
DescribeTableTask::Create(const std::string& table_name, ::milvus::grpc::TableSchema& schema) {
return std::shared_ptr<BaseTask>(new DescribeTableTask(table_name, schema));
}
ServerError DescribeTableTask::OnExecute() {
ServerError
DescribeTableTask::OnExecute() {
TimeRecorder rc("DescribeTableTask");
try {
......@@ -202,11 +206,13 @@ BuildIndexTask::BuildIndexTask(const std::string& table_name)
table_name_(table_name) {
}
BaseTaskPtr BuildIndexTask::Create(const std::string& table_name) {
BaseTaskPtr
BuildIndexTask::Create(const std::string& table_name) {
return std::shared_ptr<BaseTask>(new BuildIndexTask(table_name));
}
ServerError BuildIndexTask::OnExecute() {
ServerError
BuildIndexTask::OnExecute() {
try {
TimeRecorder rc("BuildIndexTask");
......@@ -248,11 +254,13 @@ HasTableTask::HasTableTask(const std::string& table_name, bool& has_table)
}
BaseTaskPtr HasTableTask::Create(const std::string& table_name, bool& has_table) {
BaseTaskPtr
HasTableTask::Create(const std::string& table_name, bool& has_table) {
return std::shared_ptr<BaseTask>(new HasTableTask(table_name, has_table));
}
ServerError HasTableTask::OnExecute() {
ServerError
HasTableTask::OnExecute() {
try {
TimeRecorder rc("HasTableTask");
......@@ -283,11 +291,13 @@ DropTableTask::DropTableTask(const std::string& table_name)
}
BaseTaskPtr DropTableTask::Create(const std::string& table_name) {
BaseTaskPtr
DropTableTask::Create(const std::string& table_name) {
return std::shared_ptr<BaseTask>(new DropTableTask(table_name));
}
ServerError DropTableTask::OnExecute() {
ServerError
DropTableTask::OnExecute() {
try {
TimeRecorder rc("DropTableTask");
......@@ -333,11 +343,13 @@ ShowTablesTask::ShowTablesTask(::grpc::ServerWriter< ::milvus::grpc::TableName>&
}
BaseTaskPtr ShowTablesTask::Create(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer) {
BaseTaskPtr
ShowTablesTask::Create(::grpc::ServerWriter< ::milvus::grpc::TableName>& writer) {
return std::shared_ptr<BaseTask>(new ShowTablesTask(writer));
}
ServerError ShowTablesTask::OnExecute() {
ServerError
ShowTablesTask::OnExecute() {
std::vector<engine::meta::TableSchema> schema_array;
engine::Status stat = DBWrapper::DB()->AllTables(schema_array);
if(!stat.ok()) {
......@@ -363,12 +375,14 @@ InsertVectorTask::InsertVectorTask(const ::milvus::grpc::InsertInfos& insert_inf
record_ids_.Clear();
}
BaseTaskPtr InsertVectorTask::Create(const ::milvus::grpc::InsertInfos& insert_infos,
BaseTaskPtr
InsertVectorTask::Create(const ::milvus::grpc::InsertInfos& insert_infos,
::milvus::grpc::VectorIds& record_ids) {
return std::shared_ptr<BaseTask>(new InsertVectorTask(insert_infos, record_ids));
}
ServerError InsertVectorTask::OnExecute() {
ServerError
InsertVectorTask::OnExecute() {
try {
TimeRecorder rc("InsertVectorTask");
......@@ -468,14 +482,16 @@ SearchVectorTask::SearchVectorTask(const ::milvus::grpc::SearchVectorInfos& sear
}
BaseTaskPtr SearchVectorTask::Create(const ::milvus::grpc::SearchVectorInfos& search_vector_infos,
BaseTaskPtr
SearchVectorTask::Create(const ::milvus::grpc::SearchVectorInfos& search_vector_infos,
const std::vector<std::string>& file_id_array,
::grpc::ServerWriter<::milvus::grpc::TopKQueryResult>& writer) {
return std::shared_ptr<BaseTask>(new SearchVectorTask(search_vector_infos, file_id_array,
writer));
}
ServerError SearchVectorTask::OnExecute() {
ServerError
SearchVectorTask::OnExecute() {
try {
TimeRecorder rc("SearchVectorTask");
......@@ -618,11 +634,13 @@ GetTableRowCountTask::GetTableRowCountTask(const std::string& table_name, int64_
}
BaseTaskPtr GetTableRowCountTask::Create(const std::string& table_name, int64_t& row_count) {
BaseTaskPtr
GetTableRowCountTask::Create(const std::string& table_name, int64_t& row_count) {
return std::shared_ptr<BaseTask>(new GetTableRowCountTask(table_name, row_count));
}
ServerError GetTableRowCountTask::OnExecute() {
ServerError
GetTableRowCountTask::OnExecute() {
try {
TimeRecorder rc("GetTableRowCountTask");
......@@ -659,15 +677,15 @@ PingTask::PingTask(const std::string& cmd, std::string& result)
}
BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) {
BaseTaskPtr
PingTask::Create(const std::string& cmd, std::string& result) {
return std::shared_ptr<BaseTask>(new PingTask(cmd, result));
}
ServerError PingTask::OnExecute() {
ServerError
PingTask::OnExecute() {
if(cmd_ == "version") {
result_ = MILVUS_VERSION;
} else if (cmd_ == "disconnect") {
//TODO stopservice
} else {
result_ = "OK";
}
......
......@@ -21,12 +21,15 @@ namespace server {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CreateTableTask : public BaseTask {
public:
static BaseTaskPtr Create(const ::milvus::grpc::TableSchema& schema);
static BaseTaskPtr
Create(const ::milvus::grpc::TableSchema& schema);
protected:
explicit CreateTableTask(const ::milvus::grpc::TableSchema& request);
explicit
CreateTableTask(const ::milvus::grpc::TableSchema& request);
ServerError OnExecute() override;
ServerError
OnExecute() override;
private:
const ::milvus::grpc::TableSchema schema_;
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#include "inc/Socket/Common.h"
#include "AggregatorSettings.h"
#include <memory>
#include <vector>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
enum RemoteMachineStatus : uint8_t
{
Disconnected = 0,
Connecting,
Connected
};
struct RemoteMachine
{
RemoteMachine();
std::string m_address;
std::string m_port;
Socket::ConnectionID m_connectionID;
std::atomic<RemoteMachineStatus> m_status;
};
class AggregatorContext
{
public:
AggregatorContext(const std::string& p_filePath);
~AggregatorContext();
bool IsInitialized() const;
const std::vector<std::shared_ptr<RemoteMachine>>& GetRemoteServers() const;
const std::shared_ptr<AggregatorSettings>& GetSettings() const;
private:
std::vector<std::shared_ptr<RemoteMachine>> m_remoteServers;
std::shared_ptr<AggregatorSettings> m_settings;
bool m_initialized;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/Packet.h"
#include <memory>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
typedef std::shared_ptr<Socket::RemoteSearchResult> AggregatorResult;
class AggregatorExecutionContext
{
public:
AggregatorExecutionContext(std::size_t p_totalServerNumber,
Socket::PacketHeader p_requestHeader);
~AggregatorExecutionContext();
std::size_t GetServerNumber() const;
AggregatorResult& GetResult(std::size_t p_num);
const Socket::PacketHeader& GetRequestHeader() const;
bool IsCompletedAfterFinsh(std::uint32_t p_finishedCount);
private:
std::atomic<std::uint32_t> m_unfinishedCount;
std::vector<AggregatorResult> m_results;
Socket::PacketHeader m_requestHeader;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#include "AggregatorContext.h"
#include "AggregatorExecutionContext.h"
#include "inc/Socket/Server.h"
#include "inc/Socket/Client.h"
#include "inc/Socket/ResourceManager.h"
#include <boost/asio.hpp>
#include <memory>
#include <vector>
#include <thread>
#include <condition_variable>
namespace SPTAG
{
namespace Aggregator
{
class AggregatorService
{
public:
AggregatorService();
~AggregatorService();
bool Initialize();
void Run();
private:
void StartClient();
void StartListen();
void WaitForShutdown();
void ConnectToPendingServers();
void AddToPendingServers(std::shared_ptr<RemoteMachine> p_remoteServer);
void SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void AggregateResults(std::shared_ptr<AggregatorExecutionContext> p_exectionContext);
std::shared_ptr<AggregatorContext> GetContext();
private:
typedef std::function<void(Socket::RemoteSearchResult)> AggregatorCallback;
std::shared_ptr<AggregatorContext> m_aggregatorContext;
std::shared_ptr<Socket::Server> m_socketServer;
std::shared_ptr<Socket::Client> m_socketClient;
bool m_initalized;
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
boost::asio::io_context m_ioContext;
boost::asio::signal_set m_shutdownSignals;
std::vector<std::shared_ptr<RemoteMachine>> m_pendingConnectServers;
std::mutex m_pendingConnectServersMutex;
boost::asio::deadline_timer m_pendingConnectServersTimer;
Socket::ResourceManager<AggregatorCallback> m_aggregatorCallbackManager;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#include "../Core/Common.h"
#include <string>
namespace SPTAG
{
namespace Aggregator
{
struct AggregatorSettings
{
AggregatorSettings();
std::string m_listenAddr;
std::string m_listenPort;
std::uint32_t m_searchTimeout;
SizeType m_threadNum;
SizeType m_socketThreadNum;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_CLIENTWRAPPER_H_
#define _SPTAG_CLIENT_CLIENTWRAPPER_H_
#include "inc/Socket/Client.h"
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/ResourceManager.h"
#include "Options.h"
#include <string>
#include <vector>
#include <memory>
#include <atomic>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
namespace SPTAG
{
namespace Client
{
class ClientWrapper
{
public:
typedef std::function<void(Socket::RemoteSearchResult)> Callback;
ClientWrapper(const ClientOptions& p_options);
~ClientWrapper();
void SendQueryAsync(const Socket::RemoteQuery& p_query,
Callback p_callback,
const ClientOptions& p_options);
void WaitAllFinished();
bool IsAvailable() const;
private:
typedef std::pair<Socket::ConnectionID, Socket::ConnectionID> ConnectionPair;
Socket::PacketHandlerMapPtr GetHandlerMap();
void DecreaseUnfnishedJobCount();
const ConnectionPair& GetConnection();
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void HandleDeadConnection(Socket::ConnectionID p_cid);
private:
ClientOptions m_options;
std::unique_ptr<Socket::Client> m_client;
std::atomic<std::uint32_t> m_unfinishedJobCount;
std::atomic_bool m_isWaitingFinish;
std::condition_variable m_waitingQueue;
std::mutex m_waitingMutex;
std::vector<ConnectionPair> m_connections;
std::atomic<std::uint32_t> m_spinCountOfConnection;
Socket::ResourceManager<Callback> m_callbackManager;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_OPTIONS_H_
#define _SPTAG_CLIENT_OPTIONS_H_
#include "inc/Helper/ArgumentsParser.h"
#include <string>
#include <vector>
#include <memory>
namespace SPTAG
{
namespace Client
{
class ClientOptions : public Helper::ArgumentsParser
{
public:
ClientOptions();
virtual ~ClientOptions();
std::string m_serverAddr;
std::string m_serverPort;
// in milliseconds.
std::uint32_t m_searchTimeout;
std::uint32_t m_threadNum;
std::uint32_t m_socketThreadNum;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_BKT_INDEX_H_
#define _SPTAG_BKT_INDEX_H_
#include "../Common.h"
#include "../VectorIndex.h"
#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/BKTree.h"
#include "inc/Helper/SimpleIniReader.h"
#include "inc/Helper/StringConvert.h"
#include <functional>
#include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG
{
namespace Helper
{
class IniReader;
}
namespace BKT
{
template<typename T>
class Index : public VectorIndex
{
private:
// data points
COMMON::Dataset<T> m_pSamples;
// BKT structures.
COMMON::BKTree m_pTrees;
// Graph structure
COMMON::RelativeNeighborhoodGraph m_pGraph;
std::string m_sBKTFilename;
std::string m_sGraphFilename;
std::string m_sDataPointsFilename;
std::mutex m_dataLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots;
public:
Index()
{
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \
#include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; }
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; }
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
};
} // namespace BKT
} // namespace SPTAG
#endif // _SPTAG_BKT_INDEX_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineBKTParameter
// DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber")
DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TpTreeNumber")
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit")
DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
DefineBKTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
DefineBKTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CORE_COMMONDEFS_H_
#define _SPTAG_CORE_COMMONDEFS_H_
#include <cstdint>
#include <type_traits>
#include <memory>
#include <string>
#include <limits>
#include <vector>
#include <cmath>
#ifndef _MSC_VER
#include <sys/stat.h>
#include <sys/types.h>
#define FolderSep '/'
#define mkdir(a) mkdir(a, ACCESSPERMS)
inline bool direxists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR);
}
inline bool fileexists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0;
}
template <class T>
inline T min(T a, T b) {
return a < b ? a : b;
}
template <class T>
inline T max(T a, T b) {
return a > b ? a : b;
}
#ifndef _rotl
#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n))))
#endif
#else
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <Psapi.h>
#define FolderSep '\\'
#define mkdir(a) CreateDirectory(a, NULL)
inline bool direxists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY);
}
inline bool fileexists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0;
}
#endif
namespace SPTAG
{
typedef std::uint32_t SizeType;
const float MinDist = (std::numeric_limits<float>::min)();
const float MaxDist = (std::numeric_limits<float>::max)();
const float Epsilon = 0.000000001f;
class MyException : public std::exception
{
private:
std::string Exp;
public:
MyException(std::string e) { Exp = e; }
#ifdef _MSC_VER
const char* what() const { return Exp.c_str(); }
#else
const char* what() const noexcept { return Exp.c_str(); }
#endif
};
// Type of number index.
typedef std::int32_t IndexType;
static_assert(std::is_integral<IndexType>::value, "IndexType must be integral type.");
enum class ErrorCode : std::uint16_t
{
#define DefineErrorCode(Name, Value) Name = Value,
#include "DefinitionList.h"
#undef DefineErrorCode
Undefined
};
static_assert(static_cast<std::uint16_t>(ErrorCode::Undefined) != 0, "Empty ErrorCode!");
enum class DistCalcMethod : std::uint8_t
{
#define DefineDistCalcMethod(Name) Name,
#include "DefinitionList.h"
#undef DefineDistCalcMethod
Undefined
};
static_assert(static_cast<std::uint8_t>(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!");
enum class VectorValueType : std::uint8_t
{
#define DefineVectorValueType(Name, Type) Name,
#include "DefinitionList.h"
#undef DefineVectorValueType
Undefined
};
static_assert(static_cast<std::uint8_t>(VectorValueType::Undefined) != 0, "Empty VectorValueType!");
enum class IndexAlgoType : std::uint8_t
{
#define DefineIndexAlgo(Name) Name,
#include "DefinitionList.h"
#undef DefineIndexAlgo
Undefined
};
static_assert(static_cast<std::uint8_t>(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!");
template<typename T>
constexpr VectorValueType GetEnumValueType()
{
return VectorValueType::Undefined;
}
#define DefineVectorValueType(Name, Type) \
template<> \
constexpr VectorValueType GetEnumValueType<Type>() \
{ \
return VectorValueType::Name; \
} \
#include "DefinitionList.h"
#undef DefineVectorValueType
inline std::size_t GetValueTypeSize(VectorValueType p_valueType)
{
switch (p_valueType)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return sizeof(Type); \
#include "DefinitionList.h"
#undef DefineVectorValueType
default:
break;
}
return 0;
}
} // namespace SPTAG
#endif // _SPTAG_CORE_COMMONDEFS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_COMMONUTILS_H_
#define _SPTAG_COMMON_COMMONUTILS_H_
#include "../Common.h"
#include <unordered_map>
#include <fstream>
#include <iostream>
#include <exception>
#include <algorithm>
#include <time.h>
#include <omp.h>
#include <string.h>
#define PREFETCH
#ifndef _MSC_VER
#include <stdio.h>
#include <unistd.h>
#include <sys/resource.h>
#include <cstring>
#define InterlockedCompareExchange(a,b,c) __sync_val_compare_and_swap(a, c, b)
#define Sleep(a) usleep(a * 1000)
#define strtok_s(a, b, c) strtok_r(a, b, c)
#endif
namespace SPTAG
{
namespace COMMON
{
class Utils {
public:
static int rand_int(int high = RAND_MAX, int low = 0) // Generates a random int value.
{
return low + (int)(float(high - low)*(std::rand() / (RAND_MAX + 1.0)));
}
static inline float atomic_float_add(volatile float* ptr, const float operand)
{
union {
volatile long iOld;
float fOld;
};
union {
long iNew;
float fNew;
};
while (true) {
iOld = *(volatile long *)ptr;
fNew = fOld + operand;
if (InterlockedCompareExchange((long *)ptr, iNew, iOld) == iOld) {
return fNew;
}
}
}
static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, int& NumDim) {
char* current;
char* context = NULL;
int i = 0;
double sum = 0;
arr.clear();
current = strtok_s(cstr, sep, &context);
while (current != NULL && (i < NumDim || NumDim < 0)) {
try {
float val = (float)atof(current);
arr.push_back(val);
}
catch (std::exception e) {
std::cout << "Exception:" << e.what() << std::endl;
return -2;
}
sum += arr[i] * arr[i];
current = strtok_s(NULL, sep, &context);
i++;
}
if (NumDim < 0) NumDim = i;
if (i < NumDim) return -2;
return std::sqrt(sum);
}
template <typename T>
static void Normalize(T* arr, int col, int base) {
double vecLen = 0;
for (int j = 0; j < col; j++) {
double val = arr[j];
vecLen += val * val;
}
vecLen = std::sqrt(vecLen);
if (vecLen < 1e-6) {
T val = (T)(1.0 / std::sqrt((double)col) * base);
for (int j = 0; j < col; j++) arr[j] = val;
}
else {
for (int j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base);
}
}
static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, int& D, int base, DistCalcMethod distCalcMethod) {
size_t index;
double vecLen;
if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) {
std::cout << "Parse vector error: " + currentLine << std::endl;
//throw MyException("Error in parsing data " + currentLine);
return -1;
}
if (distCalcMethod == DistCalcMethod::Cosine) {
Normalize(arr.data(), D, base);
}
return index;
}
template <typename T>
static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, int& NumQuery, int& NumDim, DistCalcMethod distCalcMethod, int base) {
std::string currentLine;
std::vector<float> arr;
int i = 0;
size_t index;
while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) {
std::getline(inStream, currentLine);
if (currentLine.length() <= 1 || (index = ProcessLine(currentLine, arr, NumDim, base, distCalcMethod)) < 0) {
continue;
}
qString.push_back(currentLine.substr(0, index));
if (Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0));
for (int j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j];
i++;
}
NumQuery = i;
std::cout << "Load data: (" << NumQuery << ", " << NumDim << ")" << std::endl;
}
template<typename T>
static inline int GetBase() {
if (GetEnumValueType<T>() != VectorValueType::Float) {
return (int)(std::numeric_limits<T>::max)();
}
return 1;
}
static inline void AddNeighbor(int idx, float dist, int *neighbors, float *dists, int size)
{
size--;
if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size]))
{
int nb;
for (nb = 0; nb <= size && neighbors[nb] != idx; nb++);
if (nb > size)
{
nb = size;
while (nb > 0 && (dist < dists[nb - 1] || (dist == dists[nb - 1] && idx < neighbors[nb - 1])))
{
dists[nb] = dists[nb - 1];
neighbors[nb] = neighbors[nb - 1];
nb--;
}
dists[nb] = dist;
neighbors[nb] = idx;
}
}
}
};
}
}
#endif // _SPTAG_COMMON_COMMONUTILS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATAUTILS_H_
#define _SPTAG_COMMON_DATAUTILS_H_
#include <sys/stat.h>
#include <atomic>
#include "CommonUtils.h"
#include "../../Helper/CommonHelper.h"
namespace SPTAG
{
namespace COMMON
{
const int bufsize = 1024 * 1024 * 1024;
class DataUtils {
public:
template <typename T>
static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
std::ifstream inputStream(filename);
if (!inputStream.is_open()) {
std::cerr << "unable to open file " + filename << std::endl;
throw MyException("unable to open file " + filename);
exit(1);
}
std::ofstream outputStream, metaStream_out, metaStream_index;
outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
throw MyException("unable to open output files");
exit(1);
}
std::vector<float> arr;
std::vector<T> sample;
int base = 1;
if (distCalcMethod == DistCalcMethod::Cosine) {
base = Utils::GetBase<T>();
}
std::uint64_t writepos = 0;
int sampleSize = 0;
std::uint64_t totalread = 0;
std::streamoff startpos = id * blocksize;
#ifndef _MSC_VER
int enter_size = 1;
#else
int enter_size = 1;
#endif
std::string currentLine;
size_t index;
inputStream.seekg(startpos, std::ifstream::beg);
if (id != 0) {
std::getline(inputStream, currentLine);
totalread += currentLine.length() + enter_size;
}
std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
while (!inputStream.eof() && totalread <= blocksize) {
std::getline(inputStream, currentLine);
if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
totalread += currentLine.length() + enter_size;
continue;
}
sample.resize(D);
for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
outputStream.write((char *)(sample.data()), sizeof(T)*D);
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_out.write(currentLine.c_str(), index);
writepos += index;
sampleSize += 1;
totalread += currentLine.length() + enter_size;
}
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_index.write((char *)&sampleSize, sizeof(int));
inputStream.close();
outputStream.close();
metaStream_out.close();
metaStream_index.close();
numSamples.fetch_add(sampleSize);
std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
}
static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int D) {
std::ifstream inputStream;
std::ofstream outputStream;
char * buf = new char[bufsize];
std::uint64_t * offsets;
int partSamples;
int metaSamples = 0;
std::uint64_t lastoff = 0;
outputStream.open(outfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
outputStream.write((char *)&D, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
outputStream.open(outmetafile, std::ofstream::binary);
for (int i = 0; i < threadbase; i++) {
std::string file = outmetafile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
delete[] buf;
outputStream.open(outmetaindexfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outmetaindexfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
inputStream.read((char *)&partSamples, sizeof(int));
offsets = new std::uint64_t[partSamples + 1];
inputStream.seekg(0, inputStream.beg);
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
inputStream.close();
remove(file.c_str());
for (int j = 0; j < partSamples + 1; j++)
offsets[j] += lastoff;
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
lastoff = offsets[partSamples];
metaSamples += partSamples;
delete[] offsets;
}
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
}
static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
std::ifstream inputStream1, inputStream2;
std::ofstream outputStream;
char * buf = new char[bufsize];
int R1, R2, C1, C2;
#define MergeVector(inputStream, vectorFile, R, C) \
inputStream.open(vectorFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
return false; \
} \
inputStream.read((char *)&(R), sizeof(int)); \
inputStream.read((char *)&(C), sizeof(int)); \
MergeVector(inputStream1, p_vectorfile1, R1, C1)
MergeVector(inputStream2, p_vectorfile2, R2, C2)
#undef MergeVector
if (C1 != C2) {
inputStream1.close(); inputStream2.close();
std::cout << "Vector dimensions are not the same!" << std::endl;
return false;
}
R1 += R2;
outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
outputStream.write((char *)&C1, sizeof(int));
while (!inputStream1.eof()) {
inputStream1.read(buf, bufsize);
outputStream.write(buf, inputStream1.gcount());
}
while (!inputStream2.eof()) {
inputStream2.read(buf, bufsize);
outputStream.write(buf, inputStream2.gcount());
}
inputStream1.close(); inputStream2.close();
outputStream.close();
if (p_metafile1 != "" && p_metafile2 != "") {
outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary);
#define MergeMeta(inputStream, metaFile) \
inputStream.open(metaFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \
return false; \
} \
while (!inputStream.eof()) { \
inputStream.read(buf, bufsize); \
outputStream.write(buf, inputStream.gcount()); \
} \
inputStream.close(); \
MergeMeta(inputStream1, p_metafile1)
MergeMeta(inputStream2, p_metafile2)
#undef MergeMeta
outputStream.close();
delete[] buf;
std::uint64_t * offsets;
int partSamples;
std::uint64_t lastoff = 0;
outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
#define MergeMetaIndex(inputStream, metaIndexFile) \
inputStream.open(metaIndexFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
return false; \
} \
inputStream.read((char *)&partSamples, sizeof(int)); \
offsets = new std::uint64_t[partSamples + 1]; \
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
inputStream.close(); \
for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \
lastoff = offsets[partSamples]; \
delete[] offsets; \
MergeMetaIndex(inputStream1, p_metaindexfile1)
MergeMetaIndex(inputStream2, p_metaindexfile2)
#undef MergeMetaIndex
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str());
rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str());
}
rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str());
std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
return true;
}
template <typename T>
static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
int threadnum, DistCalcMethod distCalcMethod) {
omp_set_num_threads(threadnum);
std::atomic_int numSamples = { 0 };
int D = -1;
int threadbase = 0;
std::vector<std::string> inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
for (std::string inputFileName : inputFileNames)
{
#ifndef _MSC_VER
struct stat stat_buf;
stat(inputFileName.c_str(), &stat_buf);
#else
struct _stat64 stat_buf;
int res = _stat64(inputFileName.c_str(), &stat_buf);
#endif
std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
#pragma omp parallel for
for (int i = 0; i < threadnum; i++) {
ProcessTSVData<T>(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
}
threadbase += threadnum;
}
MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
}
};
}
}
#endif // _SPTAG_COMMON_DATAUTILS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATASET_H_
#define _SPTAG_COMMON_DATASET_H_
#include <fstream>
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
#include <malloc.h>
#else
#include <mm_malloc.h>
#endif // defined(__GNUC__)
#define ALIGN 32
#define aligned_malloc(a, b) _mm_malloc(a, b)
#define aligned_free(a) _mm_free(a)
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// structure to save Data and Graph
template <typename T>
class Dataset
{
private:
int rows;
int cols;
bool ownData = false;
T* data = nullptr;
std::vector<T> dataIncremental;
public:
Dataset(): rows(0), cols(1) {}
Dataset(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
Initialize(rows_, cols_, data_, transferOnwership_);
}
~Dataset()
{
if (ownData) aligned_free(data);
}
void Initialize(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
rows = rows_;
cols = cols_;
data = data_;
if (data_ == nullptr || !transferOnwership_)
{
ownData = true;
data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
if (data_ != nullptr) memcpy(data, data_, rows * cols * sizeof(T));
else std::memset(data, -1, rows * cols * sizeof(T));
}
}
void SetR(int R_)
{
if (R_ >= rows)
dataIncremental.resize((R_ - rows) * cols);
else
{
rows = R_;
dataIncremental.clear();
}
}
inline int R() const { return (int)(rows + dataIncremental.size() / cols); }
inline int C() const { return cols; }
T* operator[](int index)
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
const T* operator[](int index) const
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
void AddBatch(const T* pData, int num)
{
dataIncremental.insert(dataIncremental.end(), pData, pData + num*cols);
}
void AddBatch(int num)
{
dataIncremental.insert(dataIncremental.end(), (size_t)num*cols, T(-1));
}
bool Save(std::string sDataPointsFileName)
{
std::cout << "Save Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int CR = R();
fwrite(&CR, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
T* ptr = data;
int toWrite = rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
ptr = dataIncremental.data();
toWrite = CR - rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
fclose(fp);
std::cout << "Save Data (" << CR << ", " << cols << ") Finish!" << std::endl;
return true;
}
bool Save(void **pDataPointsMemFile, int64_t &len)
{
size_t size = sizeof(int) + sizeof(int) + sizeof(T) * R() *cols;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
int CR = R();
auto header = (int*)mem;
header[0] = CR;
header[1] = cols;
auto body = &mem[8];
memcpy(body, data, sizeof(T) * cols * rows);
body += sizeof(T) * cols * rows;
memcpy(body, dataIncremental.data(), sizeof(T) * cols * (CR - rows));
body += sizeof(T) * cols * (CR - rows);
*pDataPointsMemFile = mem;
len = size;
return true;
}
bool Load(std::string sDataPointsFileName)
{
std::cout << "Load Data From " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
if (fp == NULL) return false;
int R, C;
fread(&R, sizeof(int), 1, fp);
fread(&C, sizeof(int), 1, fp);
Initialize(R, C);
T* ptr = data;
while (R > 0) {
size_t read = fread(ptr, sizeof(T) * C, R, fp);
ptr += read * C;
R -= (int)read;
}
fclose(fp);
std::cout << "Load Data (" << rows << ", " << cols << ") Finish!" << std::endl;
return true;
}
// Functions for loading models from memory mapped files
bool Load(char* pDataPointsMemFile)
{
int R, C;
R = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
C = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
Initialize(R, C, (T*)pDataPointsMemFile);
return true;
}
bool Refine(const std::vector<int>& indices, std::string sDataPointsFileName)
{
std::cout << "Save Refine Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int R = (int)(indices.size());
fwrite(&R, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
// write point one by one in case for cache miss
for (int i = 0; i < R; i++) {
if (indices[i] < rows)
fwrite(data + (size_t)indices[i] * cols, sizeof(T) * cols, 1, fp);
else
fwrite(dataIncremental.data() + (size_t)(indices[i] - rows) * cols, sizeof(T) * cols, 1, fp);
}
fclose(fp);
std::cout << "Save Refine Data (" << R << ", " << cols << ") Finish!" << std::endl;
return true;
}
};
}
}
#endif // _SPTAG_COMMON_DATASET_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#include <vector>
#include <mutex>
#include <memory>
namespace SPTAG
{
namespace COMMON
{
class FineGrainedLock {
public:
FineGrainedLock() {}
~FineGrainedLock() {
for (int i = 0; i < locks.size(); i++)
locks[i].reset();
locks.clear();
}
void resize(int n) {
int current = (int)locks.size();
if (current <= n) {
locks.resize(n);
for (int i = current; i < n; i++)
locks[i].reset(new std::mutex);
}
else {
for (int i = n; i < current; i++)
locks[i].reset();
locks.resize(n);
}
}
std::mutex& operator[](int idx) {
return *locks[idx];
}
const std::mutex& operator[](int idx) const {
return *locks[idx];
}
private:
std::vector<std::shared_ptr<std::mutex>> locks;
};
}
}
#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_
\ No newline at end of file
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_HEAP_H_
#define _SPTAG_COMMON_HEAP_H_
namespace SPTAG
{
namespace COMMON
{
// priority queue
template <typename T>
class Heap {
public:
Heap() : heap(nullptr), length(0), count(0) {}
Heap(int size) { Resize(size); }
void Resize(int size)
{
length = size;
heap.reset(new T[length + 1]); // heap uses 1-based indexing
count = 0;
lastlevel = int(pow(2.0, floor(log2(size))));
}
~Heap() {}
inline int size() { return count; }
inline bool empty() { return count == 0; }
inline void clear() { count = 0; }
inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; }
// Insert a new element in the heap.
void insert(T value)
{
/* If heap is full, then return without adding this element. */
int loc;
if (count == length) {
int maxi = lastlevel;
for (int i = lastlevel + 1; i <= length; i++)
if (heap[maxi] < heap[i]) maxi = i;
if (value > heap[maxi]) return;
loc = maxi;
}
else {
loc = ++(count); /* Remember 1-based indexing. */
}
/* Keep moving parents down until a place is found for this node. */
int par = (loc >> 1); /* Location of parent. */
while (par > 0 && value < heap[par]) {
heap[loc] = heap[par]; /* Move parent down to loc. */
loc = par;
par >>= 1;
}
/* Insert the element at the determined location. */
heap[loc] = value;
}
// Returns the node of minimum value from the heap (top of the heap).
bool pop(T& value)
{
if (count == 0) return false;
/* Switch first node with last. */
value = heap[1];
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return true; /* Return old last node. */
}
T& pop()
{
if (count == 0) return heap[0];
/* Switch first node with last. */
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return heap[count + 1]; /* Return old last node. */
}
private:
// Storage array for the heap.
// Type T must be comparable.
std::unique_ptr<T[]> heap;
int length;
int count; // Number of element in the heap
int lastlevel;
// Reorganizes the heap (a parent is smaller than its children) starting with a node.
void heapify()
{
int parent = 1, next = 2;
while (next < count) {
if (heap[next] > heap[next + 1]) next++;
if (heap[next] < heap[parent]) {
std::swap(heap[parent], heap[next]);
parent = next;
next <<= 1;
}
else break;
}
if (next == count && heap[next] < heap[parent]) std::swap(heap[parent], heap[next]);
}
};
}
}
#endif // _SPTAG_COMMON_HEAP_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_KDTREE_H_
#define _SPTAG_COMMON_KDTREE_H_
#include <iostream>
#include <vector>
#include <string>
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "QueryResultSet.h"
#include "WorkSpace.h"
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// node type for storing KDT
struct KDTNode
{
int left;
int right;
short split_dim;
float split_value;
};
class KDTree
{
public:
KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000) {}
KDTree(KDTree& other) : m_iTreeNumber(other.m_iTreeNumber),
m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit),
m_iSamples(other.m_iSamples) {}
~KDTree() {}
inline const KDTNode& operator[](int index) const { return m_pTreeRoots[index]; }
inline KDTNode& operator[](int index) { return m_pTreeRoots[index]; }
inline int size() const { return (int)m_pTreeRoots.size(); }
template <typename T>
void BuildTrees(VectorIndex* p_index, std::vector<int>* indices = nullptr)
{
std::vector<int> localindices;
if (indices == nullptr) {
localindices.resize(p_index->GetNumSamples());
for (int i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i;
}
else {
localindices.assign(indices->begin(), indices->end());
}
m_pTreeRoots.resize(m_iTreeNumber * localindices.size());
m_pTreeStart.resize(m_iTreeNumber, 0);
#pragma omp parallel for
for (int i = 0; i < m_iTreeNumber; i++)
{
Sleep(i * 100); std::srand(clock());
std::vector<int> pindices(localindices.begin(), localindices.end());
std::random_shuffle(pindices.begin(), pindices.end());
m_pTreeStart[i] = i * (int)pindices.size();
std::cout << "Start to build KDTree " << i + 1 << std::endl;
int iTreeSize = m_pTreeStart[i];
DivideTree<T>(p_index, pindices, 0, (int)pindices.size() - 1, m_pTreeStart[i], iTreeSize);
std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl;
}
}
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
{
int treeNodeSize = (int)m_pTreeRoots.size();
size_t size = sizeof(int) +
sizeof(int) * m_iTreeNumber +
sizeof(int) +
sizeof(KDTNode) * treeNodeSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iTreeNumber;
ptr += sizeof(int);
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
ptr += sizeof(int) * m_iTreeNumber;
*(int*)ptr = treeNodeSize;
ptr += sizeof(int);
memcpy(ptr, m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
*pKDTMemFile = mem;
len = size;
return true;
}
bool SaveTrees(std::string sTreeFileName) const
{
std::cout << "Save KDT to " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize = (int)m_pTreeRoots.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
fwrite(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
bool LoadTrees(char* pKDTMemFile)
{
m_iTreeNumber = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeStart.resize(m_iTreeNumber);
memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(int) * m_iTreeNumber);
pKDTMemFile += sizeof(int)*m_iTreeNumber;
int treeNodeSize = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeRoots.resize(treeNodeSize);
memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize);
return true;
}
bool LoadTrees(std::string sTreeFileName)
{
std::cout << "Load KDT From " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iTreeNumber, sizeof(int), 1, fp);
m_pTreeStart.resize(m_iTreeNumber);
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp);
m_pTreeRoots.resize(treeNodeSize);
fread(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
template <typename T>
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
for (char i = 0; i < m_iTreeNumber; i++) {
KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0);
}
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
if (p_query.worstDist() < tcell.distance) break;
KDTSearch(p_index, p_query, p_space, tcell.node, true, tcell.distance);
}
}
template <typename T>
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
KDTSearch(p_index, p_query, p_space, tcell.node, false, tcell.distance);
}
}
private:
template <typename T>
void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
COMMON::WorkSpace& p_space, const int node, const bool isInit, const float distBound) const {
if (node < 0)
{
int index = -node - 1;
if (index >= p_index->GetNumSamples()) return;
#ifdef PREFETCH
const char* data = (const char *)(p_index->GetSample(index));
_mm_prefetch(data, _MM_HINT_T0);
_mm_prefetch(data + 64, _MM_HINT_T0);
#endif
if (p_space.CheckAndSet(index)) return;
++p_space.m_iNumberOfTreeCheckedLeaves;
++p_space.m_iNumberOfCheckedLeaves;
p_space.m_NGQueue.insert(COMMON::HeapCell(index, p_index->ComputeDistance((const void*)p_query.GetTarget(), (const void*)data)));
return;
}
auto& tnode = m_pTreeRoots[node];
float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
float distanceBound = distBound + diff * diff;
int otherChild, bestChild;
if (diff < 0)
{
bestChild = tnode.left;
otherChild = tnode.right;
}
else
{
otherChild = tnode.left;
bestChild = tnode.right;
}
if (!isInit || distanceBound < p_query.worstDist())
{
p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
}
KDTSearch(p_index, p_query, p_space, bestChild, isInit, distBound);
}
template <typename T>
void DivideTree(VectorIndex* p_index, std::vector<int>& indices, int first, int last,
int index, int &iTreeSize) {
ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last);
int i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last);
if (i - 1 <= first)
{
m_pTreeRoots[index].left = -indices[first] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].left = iTreeSize;
DivideTree<T>(p_index, indices, first, i - 1, iTreeSize, iTreeSize);
}
if (last == i)
{
m_pTreeRoots[index].right = -indices[last] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].right = iTreeSize;
DivideTree<T>(p_index, indices, i, last, iTreeSize, iTreeSize);
}
}
template <typename T>
void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<int>& indices, const int first, const int last)
{
std::vector<float> meanValues(p_index->GetFeatureDim(), 0);
std::vector<float> varianceValues(p_index->GetFeatureDim(), 0);
int end = min(first + m_iSamples, last);
int count = end - first + 1;
// calculate the mean of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] += v[k];
}
}
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] /= count;
}
// calculate the variance of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
float dist = v[k] - meanValues[k];
varianceValues[k] += dist*dist;
}
}
// choose the split dimension as one of the dimension inside TOP_DIM maximum variance
node.split_dim = SelectDivisionDimension(varianceValues);
// determine the threshold
node.split_value = meanValues[node.split_dim];
}
int SelectDivisionDimension(const std::vector<float>& varianceValues) const
{
// Record the top maximum variances
std::vector<int> topind(m_numTopDimensionKDTSplit);
int num = 0;
// order the variances
for (int i = 0; i < varianceValues.size(); i++)
{
if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]])
{
if (num < m_numTopDimensionKDTSplit)
{
topind[num++] = i;
}
else
{
topind[num - 1] = i;
}
int j = num - 1;
// order the TOP_DIM variances
while (j > 0 && varianceValues[topind[j]] > varianceValues[topind[j - 1]])
{
std::swap(topind[j], topind[j - 1]);
j--;
}
}
}
// randomly choose a dimension from TOP_DIM
return topind[COMMON::Utils::rand_int(num)];
}
template <typename T>
int Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<int>& indices, const int first, const int last) const
{
int i = first;
int j = last;
// decide which child one point belongs
while (i <= j)
{
int ind = indices[i];
const T* v = (const T*)p_index->GetSample(ind);
float val = v[node.split_dim];
if (val < node.split_value)
{
i++;
}
else
{
std::swap(indices[i], indices[j]);
j--;
}
}
// if all the points in the node are equal,equally split the node into 2
if ((i == first) || (i == last + 1))
{
i = (first + last + 1) / 2;
}
return i;
}
private:
std::vector<int> m_pTreeStart;
std::vector<KDTNode> m_pTreeRoots;
public:
int m_iTreeNumber, m_numTopDimensionKDTSplit, m_iSamples;
};
}
}
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_WORKSPACEPOOL_H_
#define _SPTAG_COMMON_WORKSPACEPOOL_H_
#include "WorkSpace.h"
#include <list>
#include <mutex>
namespace SPTAG
{
namespace COMMON
{
class WorkSpacePool
{
public:
WorkSpacePool(int p_maxCheck, int p_vectorCount);
virtual ~WorkSpacePool();
std::shared_ptr<WorkSpace> Rent();
void Return(const std::shared_ptr<WorkSpace>& p_workSpace);
void Init(int size);
private:
std::list<std::shared_ptr<WorkSpace>> m_workSpacePool;
std::mutex m_workSpacePoolMutex;
int m_maxCheck;
int m_vectorCount;
};
}
}
#endif // _SPTAG_COMMON_WORKSPACEPOOL_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_THREADPOOL_H_
#define _SPTAG_INDEXBUILDER_THREADPOOL_H_
#include <functional>
#include <cstdint>
namespace SPTAG
{
namespace IndexBuilder
{
namespace ThreadPool
{
void Init(std::uint32_t p_threadNum);
bool Queue(std::function<void()> p_workItem);
std::uint32_t CurrentThreadNum();
}
} // namespace IndexBuilder
} // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_THREADPOOL_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册