提交 cfffcf9e 编写于 作者: Z zhiru

Merge remote-tracking branch 'upstream/branch-0.3.1' into branch-0.3.1


Former-commit-id: 4f3226d234fb9cbbab86cb7e85d65e97c5773e2c
......@@ -18,3 +18,4 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-161 - Add CI / CD Module to Milvus Project
- MS-202 - Add Milvus Jenkins project email notification
- MS-215 - Add Milvus cluster CI/CD groovy file
- MS-277 - Update CUDA Version to V10.1
......@@ -35,7 +35,7 @@ pipeline {
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
......
......@@ -35,7 +35,7 @@ pipeline {
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
......
......@@ -35,7 +35,7 @@ pipeline {
defaultContainer 'jnlp'
containerTemplate {
name 'milvus-build-env'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.11'
image 'registry.zilliz.com/milvus/milvus-build-env:v0.12'
ttyEnabled true
command 'cat'
}
......
......@@ -752,10 +752,7 @@ macro(build_faiss)
if(${MILVUS_WITH_FAISS_GPU_VERSION} STREQUAL "ON")
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS}
"--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}"
"--with-cuda-arch=\"-gencode=arch=compute_35,code=sm_35\""
"--with-cuda-arch=\"-gencode=arch=compute_52,code=sm_52\""
"--with-cuda-arch=\"-gencode=arch=compute_60,code=sm_60\""
"--with-cuda-arch=\"-gencode=arch=compute_61,code=sm_61\""
"--with-cuda-arch=-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_75,code=sm_75"
)
else()
set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} --without-cuda)
......@@ -769,7 +766,7 @@ macro(build_faiss)
"./configure"
${FAISS_CONFIGURE_ARGS}
BUILD_COMMAND
${MAKE} ${MAKE_BUILD_ARGS}
${MAKE} ${MAKE_BUILD_ARGS} VERBOSE=1
BUILD_IN_SOURCE
1
INSTALL_COMMAND
......@@ -1676,14 +1673,18 @@ macro(build_gperftools)
BUILD_BYPRODUCTS
${GPERFTOOLS_STATIC_LIB})
ExternalProject_Add_StepDependencies(gperftools_ep build libunwind_ep)
file(MAKE_DIRECTORY "${GPERFTOOLS_INCLUDE_DIR}")
add_library(gperftools SHARED IMPORTED)
add_library(gperftools STATIC IMPORTED)
set_target_properties(gperftools
PROPERTIES IMPORTED_LOCATION "${GPERFTOOLS_STATIC_LIB}"
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}")
INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES libunwind)
add_dependencies(gperftools gperftools_ep)
add_dependencies(gperftools libunwind_ep)
endmacro()
if(MILVUS_WITH_GPERFTOOLS)
......@@ -1692,4 +1693,5 @@ if(MILVUS_WITH_GPERFTOOLS)
# TODO: Don't use global includes but rather target_include_directories
get_target_property(GPERFTOOLS_INCLUDE_DIR gperftools INTERFACE_INCLUDE_DIRECTORIES)
include_directories(SYSTEM ${GPERFTOOLS_INCLUDE_DIR})
link_directories(SYSTEM ${GPERFTOOLS_PREFIX}/lib)
endif()
......@@ -8,6 +8,8 @@ db_config:
db_path: @MILVUS_DB_PATH@ # milvus data storage path
db_slave_path: # secondry data storage path, split by semicolon
parallel_reduce: false # use multi-threads to reduce topk result
# URI format: dialect://username:password@host:port/database
# All parts except dialect are optional, but you MUST include the delimiters
# Currently dialect supports mysql or sqlite
......
......@@ -63,10 +63,6 @@ include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
include_directories(thrift/gen-cpp)
include_directories(/usr/include/mysql)
if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
SET(PROFILER_LIB profiler)
endif()
set(third_party_libs
easyloggingpp
sqlite
......@@ -85,7 +81,6 @@ set(third_party_libs
zlib
zstd
mysqlpp
${PROFILER_LIB}
${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so
cudart
)
......@@ -103,6 +98,12 @@ else()
openblas)
endif()
if (MILVUS_ENABLE_PROFILING STREQUAL "ON")
set(third_party_libs ${third_party_libs}
gperftools
libunwind)
endif()
if (GPU_VERSION STREQUAL "ON")
link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
set(engine_libs
......
......@@ -89,7 +89,7 @@ void Cache::erase(const std::string& key) {
const DataObjPtr& data_ptr = obj_ptr->data_;
usage_ -= data_ptr->size();
SERVER_LOG_DEBUG << "Erase " << key << " from cache";
SERVER_LOG_DEBUG << "Erase " << key << " size: " << data_ptr->size();
lru_.erase(key);
}
......
......@@ -4,6 +4,7 @@
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "utils/Log.h"
#include "CacheMgr.h"
#include "metrics/Metrics.h"
......@@ -20,6 +21,7 @@ CacheMgr::~CacheMgr() {
uint64_t CacheMgr::ItemCount() const {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
......@@ -28,6 +30,7 @@ uint64_t CacheMgr::ItemCount() const {
bool CacheMgr::ItemExists(const std::string& key) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return false;
}
......@@ -36,6 +39,7 @@ bool CacheMgr::ItemExists(const std::string& key) {
DataObjPtr CacheMgr::GetItem(const std::string& key) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return nullptr;
}
server::Metrics::GetInstance().CacheAccessTotalIncrement();
......@@ -53,6 +57,7 @@ engine::Index_ptr CacheMgr::GetIndex(const std::string& key) {
void CacheMgr::InsertItem(const std::string& key, const DataObjPtr& data) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -62,6 +67,7 @@ void CacheMgr::InsertItem(const std::string& key, const DataObjPtr& data) {
void CacheMgr::InsertItem(const std::string& key, const engine::Index_ptr& index) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -72,6 +78,7 @@ void CacheMgr::InsertItem(const std::string& key, const engine::Index_ptr& index
void CacheMgr::EraseItem(const std::string& key) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -81,6 +88,7 @@ void CacheMgr::EraseItem(const std::string& key) {
void CacheMgr::PrintInfo() {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -89,6 +97,7 @@ void CacheMgr::PrintInfo() {
void CacheMgr::ClearCache() {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -97,6 +106,7 @@ void CacheMgr::ClearCache() {
int64_t CacheMgr::CacheUsage() const {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
......@@ -105,6 +115,7 @@ int64_t CacheMgr::CacheUsage() const {
int64_t CacheMgr::CacheCapacity() const {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
......@@ -113,6 +124,7 @@ int64_t CacheMgr::CacheCapacity() const {
void CacheMgr::SetCapacity(int64_t capacity) {
if(cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
cache_->set_capacity(capacity);
......
......@@ -12,10 +12,14 @@ namespace zilliz {
namespace milvus {
namespace cache {
namespace {
constexpr int64_t unit = 1024 * 1024 * 1024;
}
CpuCacheMgr::CpuCacheMgr() {
server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE);
int64_t cap = config.GetInt64Value(server::CONFIG_CPU_CACHE_CAPACITY, 16);
cap *= 1024*1024*1024;
cap *= unit;
cache_ = std::make_shared<Cache>(cap, 1UL<<32);
double free_percent = config.GetDoubleValue(server::CACHE_FREE_PERCENT, 0.85);
......
......@@ -11,10 +11,14 @@ namespace zilliz {
namespace milvus {
namespace cache {
namespace {
constexpr int64_t unit = 1024 * 1024 * 1024;
}
GpuCacheMgr::GpuCacheMgr() {
server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE);
int64_t cap = config.GetInt64Value(server::CONFIG_GPU_CACHE_CAPACITY, 1);
cap *= 1024*1024*1024;
cap *= unit;
cache_ = std::make_shared<Cache>(cap, 1UL<<32);
}
......
......@@ -94,7 +94,7 @@ double
ConfigNode::GetDoubleValue(const std::string &param_key, double default_val) const {
std::string val = GetValue(param_key);
if (!val.empty()) {
return std::strtold(val.c_str(), nullptr);
return std::strtod(val.c_str(), nullptr);
} else {
return default_val;
}
......
......@@ -9,14 +9,14 @@ namespace zilliz {
namespace milvus {
namespace engine {
const size_t K = 1024UL;
const size_t M = K * K;
const size_t G = K * M;
const size_t T = K * G;
constexpr size_t K = 1024UL;
constexpr size_t M = K * K;
constexpr size_t G = K * M;
constexpr size_t T = K * G;
const size_t MAX_TABLE_FILE_MEM = 128 * M;
constexpr size_t MAX_TABLE_FILE_MEM = 128 * M;
const int VECTOR_TYPE_SIZE = sizeof(float);
constexpr int VECTOR_TYPE_SIZE = sizeof(float);
} // namespace engine
} // namespace milvus
......
......@@ -12,11 +12,10 @@ namespace zilliz {
namespace milvus {
namespace engine {
DB::~DB() {}
DB::~DB() = default;
void DB::Open(const Options& options, DB** dbptr) {
*dbptr = DBFactory::Build(options);
return;
}
} // namespace engine
......
......@@ -52,7 +52,7 @@ public:
DB(const DB&) = delete;
DB& operator=(const DB&) = delete;
virtual ~DB();
virtual ~DB() = 0;
}; // DB
} // namespace engine
......
......@@ -89,7 +89,7 @@ DBImpl::DBImpl(const Options& options)
meta_ptr_ = DBMetaImplFactory::Build(options.meta, options.mode);
mem_mgr_ = MemManagerFactory::Build(meta_ptr_, options_);
if (options.mode != Options::MODE::READ_ONLY) {
ENGINE_LOG_INFO << "StartTimerTasks";
ENGINE_LOG_TRACE << "StartTimerTasks";
StartTimerTasks();
}
......@@ -102,6 +102,7 @@ Status DBImpl::CreateTable(meta::TableSchema& table_schema) {
Status DBImpl::DeleteTable(const std::string& table_id, const meta::DatesT& dates) {
//dates partly delete files of the table but currently we don't support
ENGINE_LOG_DEBUG << "Prepare to delete table " << table_id;
mem_mgr_->EraseMemVector(table_id); //not allow insert
meta_ptr_->DeleteTable(table_id); //soft delete table
......@@ -132,6 +133,7 @@ Status DBImpl::GetTableRowCount(const std::string& table_id, uint64_t& row_count
Status DBImpl::InsertVectors(const std::string& table_id_,
uint64_t n, const float* vectors, IDNumbers& vector_ids_) {
ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";
auto start_time = METRICS_NOW_TIME;
Status status = mem_mgr_->InsertVectors(table_id_, n, vectors, vector_ids_);
......@@ -140,6 +142,8 @@ Status DBImpl::InsertVectors(const std::string& table_id_,
// std::chrono::microseconds time_span = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
// double average_time = double(time_span.count()) / n;
ENGINE_LOG_DEBUG << "Insert vectors to cache finished";
CollectInsertMetrics(total_time, n, status.ok());
return status;
......@@ -160,6 +164,8 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
const float* vectors, const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by vectors";
//get all table files from table
meta::DatePartionedTableFilesSchema files;
auto status = meta_ptr_->FilesToSearch(table_id, dates, files);
......@@ -181,6 +187,8 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq,
Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
ENGINE_LOG_DEBUG << "Query by file ids";
//get specified files
std::vector<size_t> ids;
for (auto &id : file_ids) {
......@@ -269,6 +277,8 @@ void DBImpl::BackgroundTimerTask() {
for(auto& iter : index_thread_results_) {
iter.wait();
}
ENGINE_LOG_DEBUG << "DB background thread exit";
break;
}
......@@ -287,6 +297,8 @@ void DBImpl::StartMetricTask() {
return;
}
ENGINE_LOG_TRACE << "Start metric task";
server::Metrics::GetInstance().KeepingAliveCounterIncrement(METRIC_ACTION_INTERVAL);
int64_t cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage();
int64_t cache_total = cache::CpuCacheMgr::GetInstance()->CacheCapacity();
......@@ -299,17 +311,14 @@ void DBImpl::StartMetricTask() {
server::Metrics::GetInstance().GPUPercentGaugeSet();
server::Metrics::GetInstance().GPUMemoryUsageGaugeSet();
server::Metrics::GetInstance().OctetsSet();
ENGINE_LOG_TRACE << "Metric task finished";
}
void DBImpl::StartCompactionTask() {
// static int count = 0;
// count++;
// std::cout << "StartCompactionTask: " << count << std::endl;
// std::cout << "c: " << count++ << std::endl;
static uint64_t compact_clock_tick = 0;
compact_clock_tick++;
if(compact_clock_tick%COMPACT_ACTION_INTERVAL != 0) {
// std::cout << "c r: " << count++ << std::endl;
return;
}
......@@ -320,6 +329,10 @@ void DBImpl::StartCompactionTask() {
compact_table_ids_.insert(id);
}
if(!temp_table_ids.empty()) {
SERVER_LOG_DEBUG << "Insert cache serialized";
}
//compactiong has been finished?
if(!compact_thread_results_.empty()) {
std::chrono::milliseconds span(10);
......@@ -338,13 +351,15 @@ void DBImpl::StartCompactionTask() {
Status DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date,
const meta::TableFilesSchema& files) {
ENGINE_LOG_DEBUG << "Merge files for table" << table_id;
meta::TableFileSchema table_file;
table_file.table_id_ = table_id;
table_file.date_ = date;
Status status = meta_ptr_->CreateTableFile(table_file);
if (!status.ok()) {
ENGINE_LOG_INFO << status.ToString() << std::endl;
ENGINE_LOG_ERROR << "Failed to create table: " << status.ToString();
return status;
}
......@@ -396,6 +411,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
meta::DatePartionedTableFilesSchema raw_files;
auto status = meta_ptr_->FilesToMerge(table_id, raw_files);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to get merge files for table: " << table_id;
return status;
}
......@@ -417,12 +433,14 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) {
}
void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
ENGINE_LOG_TRACE << " Background compaction thread start";
Status status;
for (auto& table_id : table_ids) {
status = BackgroundMergeFiles(table_id);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Merge files for table " << table_id << " failed: " << status.ToString();
return;
continue;//let other table get chance to merge
}
}
......@@ -433,6 +451,8 @@ void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
ttl = meta::D_SEC;
}
meta_ptr_->CleanUpFilesWithTTL(ttl);
ENGINE_LOG_TRACE << " Background compaction thread exit";
}
void DBImpl::StartBuildIndexTask(bool force) {
......@@ -477,6 +497,7 @@ Status DBImpl::BuildIndex(const std::string& table_id) {
Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
ExecutionEnginePtr to_index = EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_);
if(to_index == nullptr) {
ENGINE_LOG_ERROR << "Invalid engine type";
return Status::Error("Invalid engine type");
}
......@@ -491,6 +512,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
table_file.file_type_ = meta::TableFileSchema::INDEX; //for multi-db-path, distribute index file averagely to each path
Status status = meta_ptr_->CreateTableFile(table_file);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to create table: " << status.ToString();
return status;
}
......@@ -559,6 +581,8 @@ Status DBImpl::BuildIndexByTable(const std::string& table_id) {
}
void DBImpl::BackgroundBuildIndex() {
ENGINE_LOG_TRACE << " Background build index thread start";
std::unique_lock<std::mutex> lock(build_index_mutex_);
meta::TableFilesSchema to_index_files;
meta_ptr_->FilesToIndex(to_index_files);
......@@ -574,6 +598,8 @@ void DBImpl::BackgroundBuildIndex() {
break;
}
}
ENGINE_LOG_TRACE << " Background build index thread exit";
}
Status DBImpl::DropAll() {
......
......@@ -8,67 +8,88 @@
#include "Meta.h"
#include "Options.h"
namespace zilliz {
namespace milvus {
namespace engine {
namespace meta {
auto StoragePrototype(const std::string& path);
auto StoragePrototype(const std::string &path);
class DBMetaImpl : public Meta {
public:
DBMetaImpl(const DBMetaOptions& options_);
public:
explicit DBMetaImpl(const DBMetaOptions &options_);
Status
CreateTable(TableSchema &table_schema) override;
Status
DescribeTable(TableSchema &group_info_) override;
Status
HasTable(const std::string &table_id, bool &has_or_not) override;
Status
AllTables(std::vector<TableSchema> &table_schema_array) override;
Status
DeleteTable(const std::string &table_id) override;
virtual Status CreateTable(TableSchema& table_schema) override;
virtual Status DescribeTable(TableSchema& group_info_) override;
virtual Status HasTable(const std::string& table_id, bool& has_or_not) override;
virtual Status AllTables(std::vector<TableSchema>& table_schema_array) override;
Status
DeleteTableFiles(const std::string &table_id) override;
virtual Status DeleteTable(const std::string& table_id) override;
virtual Status DeleteTableFiles(const std::string& table_id) override;
Status
CreateTableFile(TableFileSchema &file_schema) override;
virtual Status CreateTableFile(TableFileSchema& file_schema) override;
virtual Status DropPartitionsByDates(const std::string& table_id,
const DatesT& dates) override;
Status
DropPartitionsByDates(const std::string &table_id, const DatesT &dates) override;
virtual Status GetTableFiles(const std::string& table_id,
const std::vector<size_t>& ids,
TableFilesSchema& table_files) override;
Status
GetTableFiles(const std::string &table_id, const std::vector<size_t> &ids, TableFilesSchema &table_files) override;
virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) override;
Status
HasNonIndexFiles(const std::string &table_id, bool &has) override;
virtual Status UpdateTableFilesToIndex(const std::string& table_id) override;
Status
UpdateTableFilesToIndex(const std::string &table_id) override;
virtual Status UpdateTableFile(TableFileSchema& file_schema) override;
Status
UpdateTableFile(TableFileSchema &file_schema) override;
virtual Status UpdateTableFiles(TableFilesSchema& files) override;
Status
UpdateTableFiles(TableFilesSchema &files) override;
virtual Status FilesToSearch(const std::string& table_id,
const DatesT& partition,
DatePartionedTableFilesSchema& files) override;
Status
FilesToSearch(const std::string &table_id, const DatesT &partition, DatePartionedTableFilesSchema &files) override;
virtual Status FilesToMerge(const std::string& table_id,
DatePartionedTableFilesSchema& files) override;
Status
FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) override;
virtual Status FilesToIndex(TableFilesSchema&) override;
Status
FilesToIndex(TableFilesSchema &) override;
virtual Status Archive() override;
Status
Archive() override;
virtual Status Size(uint64_t& result) override;
Status
Size(uint64_t &result) override;
virtual Status CleanUp() override;
Status
CleanUp() override;
virtual Status CleanUpFilesWithTTL(uint16_t seconds) override;
Status
CleanUpFilesWithTTL(uint16_t seconds) override;
virtual Status DropAll() override;
Status
DropAll() override;
virtual Status Count(const std::string& table_id, uint64_t& result) override;
Status Count(const std::string &table_id, uint64_t &result) override;
virtual ~DBMetaImpl();
~DBMetaImpl() override;
private:
Status NextFileId(std::string& file_id);
Status NextTableId(std::string& table_id);
private:
Status NextFileId(std::string &file_id);
Status NextTableId(std::string &table_id);
Status DiscardFiles(long to_discard_size);
Status Initialize();
......
......@@ -13,7 +13,9 @@ namespace zilliz {
namespace milvus {
namespace engine {
IDGenerator::~IDGenerator() {}
IDGenerator::~IDGenerator() = default;
constexpr size_t SimpleIDGenerator::MAX_IDS_PER_MICRO;
IDNumber SimpleIDGenerator::GetNextIDNumber() {
auto now = std::chrono::system_clock::now();
......
......@@ -10,28 +10,39 @@
#include <cstddef>
#include <vector>
namespace zilliz {
namespace milvus {
namespace engine {
class IDGenerator {
public:
virtual IDNumber GetNextIDNumber() = 0;
virtual void GetNextIDNumbers(size_t n, IDNumbers& ids) = 0;
public:
virtual
IDNumber GetNextIDNumber() = 0;
virtual ~IDGenerator();
virtual void
GetNextIDNumbers(size_t n, IDNumbers &ids) = 0;
virtual
~IDGenerator() = 0;
}; // IDGenerator
class SimpleIDGenerator : public IDGenerator {
public:
virtual IDNumber GetNextIDNumber() override;
virtual void GetNextIDNumbers(size_t n, IDNumbers& ids) override;
public:
~SimpleIDGenerator() override = default;
IDNumber
GetNextIDNumber() override;
void
GetNextIDNumbers(size_t n, IDNumbers &ids) override;
private:
void
NextIDNumbers(size_t n, IDNumbers &ids);
private:
void NextIDNumbers(size_t n, IDNumbers& ids);
const size_t MAX_IDS_PER_MICRO = 1000;
static constexpr size_t MAX_IDS_PER_MICRO = 1000;
}; // SimpleIDGenerator
......
......@@ -13,6 +13,8 @@ namespace milvus {
namespace engine {
namespace meta {
Meta::~Meta() = default;
DateT Meta::GetDate(const std::time_t& t, int day_delta) {
struct tm ltm;
localtime_r(&t, &ltm);
......
......@@ -20,56 +20,86 @@ namespace meta {
class Meta {
public:
public:
using Ptr = std::shared_ptr<Meta>;
virtual Status CreateTable(TableSchema& table_schema) = 0;
virtual Status DescribeTable(TableSchema& table_schema) = 0;
virtual Status HasTable(const std::string& table_id, bool& has_or_not) = 0;
virtual Status AllTables(std::vector<TableSchema>& table_schema_array) = 0;
virtual
~Meta() = 0;
virtual Status DeleteTable(const std::string& table_id) = 0;
virtual Status DeleteTableFiles(const std::string& table_id) = 0;
virtual Status
CreateTable(TableSchema &table_schema) = 0;
virtual Status CreateTableFile(TableFileSchema& file_schema) = 0;
virtual Status DropPartitionsByDates(const std::string& table_id,
const DatesT& dates) = 0;
virtual Status
DescribeTable(TableSchema &table_schema) = 0;
virtual Status GetTableFiles(const std::string& table_id,
const std::vector<size_t>& ids,
TableFilesSchema& table_files) = 0;
virtual Status
HasTable(const std::string &table_id, bool &has_or_not) = 0;
virtual Status UpdateTableFilesToIndex(const std::string& table_id) = 0;
virtual Status
AllTables(std::vector<TableSchema> &table_schema_array) = 0;
virtual Status UpdateTableFile(TableFileSchema& file_schema) = 0;
virtual Status
DeleteTable(const std::string &table_id) = 0;
virtual Status UpdateTableFiles(TableFilesSchema& files) = 0;
virtual Status
DeleteTableFiles(const std::string &table_id) = 0;
virtual Status FilesToSearch(const std::string &table_id,
const DatesT &partition,
DatePartionedTableFilesSchema& files) = 0;
virtual Status
CreateTableFile(TableFileSchema &file_schema) = 0;
virtual Status FilesToMerge(const std::string& table_id,
DatePartionedTableFilesSchema& files) = 0;
virtual Status
DropPartitionsByDates(const std::string &table_id, const DatesT &dates) = 0;
virtual Status Size(uint64_t& result) = 0;
virtual Status
GetTableFiles(const std::string &table_id, const std::vector<size_t> &ids, TableFilesSchema &table_files) = 0;
virtual Status Archive() = 0;
virtual Status
UpdateTableFilesToIndex(const std::string &table_id) = 0;
virtual Status FilesToIndex(TableFilesSchema&) = 0;
virtual Status
UpdateTableFile(TableFileSchema &file_schema) = 0;
virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) = 0;
virtual Status
UpdateTableFiles(TableFilesSchema &files) = 0;
virtual Status CleanUp() = 0;
virtual Status CleanUpFilesWithTTL(uint16_t) = 0;
virtual Status
FilesToSearch(const std::string &table_id, const DatesT &partition, DatePartionedTableFilesSchema &files) = 0;
virtual Status DropAll() = 0;
virtual Status
FilesToMerge(const std::string &table_id, DatePartionedTableFilesSchema &files) = 0;
virtual Status Count(const std::string& table_id, uint64_t& result) = 0;
virtual Status
Size(uint64_t &result) = 0;
static DateT GetDate(const std::time_t& t, int day_delta = 0);
static DateT GetDate();
static DateT GetDateWithDelta(int day_delta);
virtual Status
Archive() = 0;
virtual Status
FilesToIndex(TableFilesSchema &) = 0;
virtual Status
HasNonIndexFiles(const std::string &table_id, bool &has) = 0;
virtual Status
CleanUp() = 0;
virtual Status
CleanUpFilesWithTTL(uint16_t) = 0;
virtual Status
DropAll() = 0;
virtual Status
Count(const std::string &table_id, uint64_t &result) = 0;
static DateT
GetDate(const std::time_t &t, int day_delta = 0);
static DateT
GetDate();
static DateT
GetDateWithDelta(int day_delta);
}; // MetaData
......
......@@ -12,79 +12,80 @@
#include "mysql++/mysql++.h"
#include <mutex>
namespace zilliz {
namespace milvus {
namespace engine {
namespace meta {
// auto StoragePrototype(const std::string& path);
using namespace mysqlpp;
using namespace mysqlpp;
class MySQLMetaImpl : public Meta {
public:
MySQLMetaImpl(const DBMetaOptions& options_, const int& mode);
class MySQLMetaImpl : public Meta {
public:
MySQLMetaImpl(const DBMetaOptions &options_, const int &mode);
virtual Status CreateTable(TableSchema& table_schema) override;
virtual Status DescribeTable(TableSchema& group_info_) override;
virtual Status HasTable(const std::string& table_id, bool& has_or_not) override;
virtual Status AllTables(std::vector<TableSchema>& table_schema_array) override;
Status CreateTable(TableSchema &table_schema) override;
Status DescribeTable(TableSchema &group_info_) override;
Status HasTable(const std::string &table_id, bool &has_or_not) override;
Status AllTables(std::vector<TableSchema> &table_schema_array) override;
virtual Status DeleteTable(const std::string& table_id) override;
virtual Status DeleteTableFiles(const std::string& table_id) override;
Status DeleteTable(const std::string &table_id) override;
Status DeleteTableFiles(const std::string &table_id) override;
virtual Status CreateTableFile(TableFileSchema& file_schema) override;
virtual Status DropPartitionsByDates(const std::string& table_id,
const DatesT& dates) override;
Status CreateTableFile(TableFileSchema &file_schema) override;
Status DropPartitionsByDates(const std::string &table_id,
const DatesT &dates) override;
virtual Status GetTableFiles(const std::string& table_id,
const std::vector<size_t>& ids,
TableFilesSchema& table_files) override;
Status GetTableFiles(const std::string &table_id,
const std::vector<size_t> &ids,
TableFilesSchema &table_files) override;
virtual Status HasNonIndexFiles(const std::string& table_id, bool& has) override;
Status HasNonIndexFiles(const std::string &table_id, bool &has) override;
virtual Status UpdateTableFile(TableFileSchema& file_schema) override;
Status UpdateTableFile(TableFileSchema &file_schema) override;
virtual Status UpdateTableFilesToIndex(const std::string& table_id) override;
Status UpdateTableFilesToIndex(const std::string &table_id) override;
virtual Status UpdateTableFiles(TableFilesSchema& files) override;
Status UpdateTableFiles(TableFilesSchema &files) override;
virtual Status FilesToSearch(const std::string& table_id,
const DatesT& partition,
DatePartionedTableFilesSchema& files) override;
Status FilesToSearch(const std::string &table_id,
const DatesT &partition,
DatePartionedTableFilesSchema &files) override;
virtual Status FilesToMerge(const std::string& table_id,
DatePartionedTableFilesSchema& files) override;
Status FilesToMerge(const std::string &table_id,
DatePartionedTableFilesSchema &files) override;
virtual Status FilesToIndex(TableFilesSchema&) override;
Status FilesToIndex(TableFilesSchema &) override;
virtual Status Archive() override;
Status Archive() override;
virtual Status Size(uint64_t& result) override;
Status Size(uint64_t &result) override;
virtual Status CleanUp() override;
Status CleanUp() override;
virtual Status CleanUpFilesWithTTL(uint16_t seconds) override;
Status CleanUpFilesWithTTL(uint16_t seconds) override;
virtual Status DropAll() override;
Status DropAll() override;
virtual Status Count(const std::string& table_id, uint64_t& result) override;
Status Count(const std::string &table_id, uint64_t &result) override;
virtual ~MySQLMetaImpl();
virtual ~MySQLMetaImpl();
private:
Status NextFileId(std::string& file_id);
Status NextTableId(std::string& table_id);
Status DiscardFiles(long long to_discard_size);
Status Initialize();
private:
Status NextFileId(std::string &file_id);
Status NextTableId(std::string &table_id);
Status DiscardFiles(long long to_discard_size);
Status Initialize();
const DBMetaOptions options_;
const int mode_;
const DBMetaOptions options_;
const int mode_;
std::shared_ptr<MySQLConnectionPool> mysql_connection_pool_;
bool safe_grab = false;
std::shared_ptr<MySQLConnectionPool> mysql_connection_pool_;
bool safe_grab = false;
// std::mutex connectionMutex_;
}; // DBMetaImpl
}; // DBMetaImpl
} // namespace meta
} // namespace engine
......
......@@ -20,6 +20,7 @@ class ReuseCacheIndexStrategy {
public:
bool Schedule(const SearchContextPtr &context, std::list<ScheduleTaskPtr>& task_list) {
if(context == nullptr) {
ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
return false;
}
......@@ -64,6 +65,7 @@ class DeleteTableStrategy {
public:
bool Schedule(const DeleteContextPtr &context, std::list<ScheduleTaskPtr> &task_list) {
if (context == nullptr) {
ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
return false;
}
......@@ -103,6 +105,7 @@ public:
bool TaskDispatchStrategy::Schedule(const ScheduleContextPtr &context_ptr,
std::list<zilliz::milvus::engine::ScheduleTaskPtr> &task_list) {
if(context_ptr == nullptr) {
ENGINE_LOG_ERROR << "Task Dispatch context doesn't exist";
return false;
}
......
......@@ -31,6 +31,7 @@ TaskScheduler& TaskScheduler::GetInstance() {
bool
TaskScheduler::Start() {
if(!stopped_) {
SERVER_LOG_INFO << "Task Scheduler isn't started";
return true;
}
......@@ -47,6 +48,7 @@ TaskScheduler::Start() {
bool
TaskScheduler::Stop() {
if(stopped_) {
SERVER_LOG_INFO << "Task Scheduler already stopped";
return true;
}
......@@ -80,7 +82,7 @@ TaskScheduler::TaskDispatchWorker() {
ScheduleTaskPtr task_ptr = task_dispatch_queue_.Take();
if(task_ptr == nullptr) {
SERVER_LOG_INFO << "Stop db task dispatch thread";
break;//exit
return true;
}
//execute task
......@@ -98,8 +100,8 @@ TaskScheduler::TaskWorker() {
while(true) {
ScheduleTaskPtr task_ptr = task_queue_.Take();
if(task_ptr == nullptr) {
SERVER_LOG_INFO << "Stop db task thread";
break;//exit
SERVER_LOG_INFO << "Stop db task worker thread";
return true;
}
//execute task
......
......@@ -5,14 +5,60 @@
******************************************************************************/
#include "SearchTask.h"
#include "metrics/Metrics.h"
#include "utils/Log.h"
#include "db/Log.h"
#include "utils/TimeRecorder.h"
#include <thread>
namespace zilliz {
namespace milvus {
namespace engine {
namespace {
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
bool NeedParallelReduce(uint64_t nq, uint64_t topk) {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode& db_config = config.GetConfig(server::CONFIG_DB);
bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, true);
if(!need_parallel) {
return false;
}
return nq*topk >= PARALLEL_REDUCE_THRESHOLD;
}
void ParallelReduce(std::function<void(size_t, size_t)>& reduce_function, size_t max_index) {
size_t reduce_batch = PARALLEL_REDUCE_BATCH;
auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
if(thread_count > 0) {
reduce_batch = max_index/thread_count + 1;
}
ENGINE_LOG_DEBUG << "use " << thread_count <<
" thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
std::vector<std::shared_ptr<std::thread> > thread_array;
size_t from_index = 0;
while(from_index < max_index) {
size_t to_index = from_index + reduce_batch;
if(to_index > max_index) {
to_index = max_index;
}
auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
thread_array.push_back(reduce_thread);
from_index = to_index;
}
for(auto& thread_ptr : thread_array) {
thread_ptr->join();
}
}
void CollectDurationMetrics(int index_type, double total_time) {
switch(index_type) {
case meta::TableFileSchema::RAW: {
......@@ -32,7 +78,7 @@ void CollectDurationMetrics(int index_type, double total_time) {
std::string GetMetricType() {
server::ServerConfig &config = server::ServerConfig::GetInstance();
server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
server::ConfigNode& engine_config = config.GetConfig(server::CONFIG_ENGINE);
return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
}
......@@ -51,7 +97,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
return nullptr;
}
SERVER_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_<< " with "
<< search_contexts_.size() << " tasks";
server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
......@@ -79,6 +125,9 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
span = rc.RecordSection("cluster result for context:" + context->Identity());
context->AccumReduceCost(span);
//step 4: pick up topk result
SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
......@@ -86,7 +135,7 @@ std::shared_ptr<IScheduleTask> SearchTask::Execute() {
context->AccumReduceCost(span);
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
continue;
}
......@@ -112,23 +161,32 @@ Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
if(output_ids.size() < nq*topk || output_distence.size() < nq*topk) {
std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
" distance array size: " + std::to_string(output_distence.size());
SERVER_LOG_ERROR << msg;
ENGINE_LOG_ERROR << msg;
return Status::Error(msg);
}
result_set.clear();
result_set.reserve(nq);
for (auto i = 0; i < nq; i++) {
SearchContext::Id2DistanceMap id_distance;
id_distance.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if(output_ids[index] < 0) {
continue;
result_set.resize(nq);
std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
for (auto i = from_index; i < to_index; i++) {
SearchContext::Id2DistanceMap id_distance;
id_distance.reserve(topk);
for (auto k = 0; k < topk; k++) {
uint64_t index = i * topk + k;
if(output_ids[index] < 0) {
continue;
}
id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
result_set[i] = id_distance;
}
result_set.emplace_back(id_distance);
};
if(NeedParallelReduce(nq, topk)) {
ParallelReduce(reduce_worker, nq);
} else {
reduce_worker(0, nq);
}
return Status::OK();
......@@ -140,7 +198,7 @@ Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
bool ascending) {
//Note: the score_src and score_target are already arranged by score in ascending order
if(distance_src.empty()) {
SERVER_LOG_WARNING << "Empty distance source array";
ENGINE_LOG_WARNING << "Empty distance source array";
return Status::OK();
}
......@@ -218,14 +276,22 @@ Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
if (result_src.size() != result_target.size()) {
std::string msg = "Invalid result set size";
SERVER_LOG_ERROR << msg;
ENGINE_LOG_ERROR << msg;
return Status::Error(msg);
}
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Id2DistanceMap &score_src = result_src[i];
SearchContext::Id2DistanceMap &score_target = result_target[i];
SearchTask::MergeResult(score_src, score_target, topk, ascending);
std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
for (size_t i = from_index; i < to_index; i++) {
SearchContext::Id2DistanceMap &score_src = result_src[i];
SearchContext::Id2DistanceMap &score_target = result_target[i];
SearchTask::MergeResult(score_src, score_target, topk, ascending);
}
};
if(NeedParallelReduce(result_src.size(), topk)) {
ParallelReduce(ReduceWorker, result_src.size());
} else {
ReduceWorker(0, result_src.size());
}
return Status::OK();
......
......@@ -233,21 +233,22 @@ ClientTest::Test(const std::string& address, const std::string& port) {
PrintTableSchema(tb_schema);
}
//add vectors
std::vector<std::pair<int64_t, RowRecord>> search_record_array;
{//add vectors
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
TimeRecorder recorder("Add vector No." + std::to_string(i));
std::vector<RowRecord> record_array;
int64_t begin_index = i * BATCH_ROW_COUNT;
BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
std::vector<int64_t> record_ids;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
if(search_record_array.size() < NQ) {
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {
TimeRecorder recorder("Add vector No." + std::to_string(i));
std::vector<RowRecord> record_array;
int64_t begin_index = i * BATCH_ROW_COUNT;
BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
std::vector<int64_t> record_ids;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
if(i == 0) {
for(int64_t k = SEARCH_TARGET; k < SEARCH_TARGET + NQ; k++) {
search_record_array.push_back(
std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
std::make_pair(record_ids[k], record_array[k]));
}
}
}
......
......@@ -191,6 +191,7 @@ ServerError CreateTableTask::OnExecute() {
}
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "CreateTableTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -236,6 +237,7 @@ ServerError DescribeTableTask::OnExecute() {
schema_.store_raw_vector = table_info.store_raw_data_;
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "DescribeTableTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -279,6 +281,7 @@ ServerError BuildIndexTask::OnExecute() {
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "BuildIndexTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -316,6 +319,7 @@ ServerError HasTableTask::OnExecute() {
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "HasTableTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -365,6 +369,7 @@ ServerError DeleteTableTask::OnExecute() {
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "DeleteTableTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -481,6 +486,7 @@ ServerError AddVectorTask::OnExecute() {
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "AddVectorTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -604,6 +610,7 @@ ServerError SearchVectorTaskBase::OnExecute() {
<< " construct result(" << (span_result/total_cost)*100.0 << "%)";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "SearchVectorTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......@@ -739,6 +746,7 @@ ServerError GetTableRowCountTask::OnExecute() {
rc.ElapseFromBegin("totally cost");
} catch (std::exception& ex) {
SERVER_LOG_ERROR << "GetTableRowCountTask encounter exception: " << ex.what();
return SetError(SERVER_UNEXPECTED_ERROR, ex.what());
}
......
......@@ -29,6 +29,7 @@ static const std::string CONFIG_DB_INDEX_TRIGGER_SIZE = "index_building_threshol
static const std::string CONFIG_DB_ARCHIVE_DISK = "archive_disk_threshold";
static const std::string CONFIG_DB_ARCHIVE_DAYS = "archive_days_threshold";
static const std::string CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size";
static const std::string CONFIG_DB_PARALLEL_REDUCE = "parallel_reduce";
static const std::string CONFIG_LOG = "log_config";
......
......@@ -6,6 +6,8 @@
#include <gtest/gtest.h>
#include "db/scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
#include <cmath>
#include <vector>
......@@ -17,27 +19,33 @@ static constexpr uint64_t NQ = 15;
static constexpr uint64_t TOP_K = 64;
void BuildResult(uint64_t nq,
uint64_t top_k,
uint64_t topk,
bool ascending,
std::vector<long> &output_ids,
std::vector<float> &output_distence) {
output_ids.clear();
output_ids.resize(nq*top_k);
output_ids.resize(nq*topk);
output_distence.clear();
output_distence.resize(nq*top_k);
output_distence.resize(nq*topk);
for(uint64_t i = 0; i < nq; i++) {
for(uint64_t j = 0; j < top_k; j++) {
output_ids[i * top_k + j] = (long)(drand48()*100000);
output_distence[i * top_k + j] = j + drand48();
for(uint64_t j = 0; j < topk; j++) {
output_ids[i * topk + j] = (long)(drand48()*100000);
output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
}
}
}
void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
const engine::SearchContext::Id2DistanceMap& src_2,
const engine::SearchContext::Id2DistanceMap& target) {
const engine::SearchContext::Id2DistanceMap& target,
bool ascending) {
for(uint64_t i = 0; i < target.size() - 1; i++) {
ASSERT_LE(target[i].second, target[i + 1].second);
if(ascending) {
ASSERT_LE(target[i].second, target[i + 1].second);
} else {
ASSERT_GE(target[i].second, target[i + 1].second);
}
}
using ID2DistMap = std::map<long, float>;
......@@ -57,9 +65,52 @@ void CheckResult(const engine::SearchContext::Id2DistanceMap& src_1,
}
}
void CheckCluster(const std::vector<long>& target_ids,
const std::vector<float>& target_distence,
const engine::SearchContext::ResultSet& src_result,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
for(int64_t i = 0; i < nq; i++) {
auto& res = src_result[i];
ASSERT_EQ(res.size(), topk);
if(res.empty()) {
continue;
}
ASSERT_EQ(res[0].first, target_ids[i*topk]);
ASSERT_EQ(res[topk - 1].first, target_ids[i*topk + topk - 1]);
}
}
void CheckTopkResult(const engine::SearchContext::ResultSet& src_result,
bool ascending,
int64_t nq,
int64_t topk) {
ASSERT_EQ(src_result.size(), nq);
for(int64_t i = 0; i < nq; i++) {
auto& res = src_result[i];
ASSERT_EQ(res.size(), topk);
if(res.empty()) {
continue;
}
for(int64_t k = 0; k < topk - 1; k++) {
if(ascending) {
ASSERT_LE(res[k].second, res[k + 1].second);
} else {
ASSERT_GE(res[k].second, res[k + 1].second);
}
}
}
}
}
TEST(DBSearchTest, TOPK_TEST) {
bool ascending = true;
std::vector<long> target_ids;
std::vector<float> target_distence;
engine::SearchContext::ResultSet src_result;
......@@ -67,19 +118,19 @@ TEST(DBSearchTest, TOPK_TEST) {
ASSERT_FALSE(status.ok());
ASSERT_TRUE(src_result.empty());
BuildResult(NQ, TOP_K, target_ids, target_distence);
BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
status = engine::SearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
ASSERT_TRUE(status.ok());
ASSERT_EQ(src_result.size(), NQ);
engine::SearchContext::ResultSet target_result;
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, target_result);
status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(target_result, TOP_K, true, src_result);
status = engine::SearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
ASSERT_FALSE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
ASSERT_TRUE(src_result.empty());
ASSERT_EQ(target_result.size(), NQ);
......@@ -87,21 +138,21 @@ TEST(DBSearchTest, TOPK_TEST) {
std::vector<long> src_ids;
std::vector<float> src_distence;
uint64_t wrong_topk = TOP_K - 10;
BuildResult(NQ, wrong_topk, src_ids, src_distence);
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
status = engine::SearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
}
wrong_topk = TOP_K + 10;
BuildResult(NQ, wrong_topk, src_ids, src_distence);
BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
status = engine::SearchTask::TopkResult(src_result, TOP_K, true, target_result);
status = engine::SearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
ASSERT_TRUE(status.ok());
for(uint64_t i = 0; i < NQ; i++) {
ASSERT_EQ(target_result[i].size(), TOP_K);
......@@ -109,6 +160,7 @@ TEST(DBSearchTest, TOPK_TEST) {
}
TEST(DBSearchTest, MERGE_TEST) {
bool ascending = true;
std::vector<long> target_ids;
std::vector<float> target_distence;
std::vector<long> src_ids;
......@@ -116,8 +168,8 @@ TEST(DBSearchTest, MERGE_TEST) {
engine::SearchContext::ResultSet src_result, target_result;
uint64_t src_count = 5, target_count = 8;
BuildResult(1, src_count, src_ids, src_distence);
BuildResult(1, target_count, target_ids, target_distence);
BuildResult(1, src_count, ascending, src_ids, src_distence);
BuildResult(1, target_count, ascending, target_ids, target_distence);
auto status = engine::SearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
ASSERT_TRUE(status.ok());
status = engine::SearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
......@@ -126,37 +178,107 @@ TEST(DBSearchTest, MERGE_TEST) {
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 10, true);
status = engine::SearchTask::MergeResult(src, target, 10, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), 10);
CheckResult(src_result[0], target_result[0], target);
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target;
status = engine::SearchTask::MergeResult(src, target, 10, true);
status = engine::SearchTask::MergeResult(src, target, 10, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count);
ASSERT_TRUE(src.empty());
CheckResult(src_result[0], target_result[0], target);
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
engine::SearchContext::Id2DistanceMap src = src_result[0];
engine::SearchContext::Id2DistanceMap target = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30, true);
status = engine::SearchTask::MergeResult(src, target, 30, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
CheckResult(src_result[0], target_result[0], target, ascending);
}
{
engine::SearchContext::Id2DistanceMap target = src_result[0];
engine::SearchContext::Id2DistanceMap src = target_result[0];
status = engine::SearchTask::MergeResult(src, target, 30, true);
status = engine::SearchTask::MergeResult(src, target, 30, ascending);
ASSERT_TRUE(status.ok());
ASSERT_EQ(target.size(), src_count + target_count);
CheckResult(src_result[0], target_result[0], target);
CheckResult(src_result[0], target_result[0], target, ascending);
}
}
TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
bool ascending = true;
std::vector<long> target_ids;
std::vector<float> target_distence;
engine::SearchContext::ResultSet src_result;
auto DoCluster = [&](int64_t nq, int64_t topk) {
server::TimeRecorder rc("DoCluster");
src_result.clear();
BuildResult(nq, topk, ascending, target_ids, target_distence);
rc.RecordSection("build id/dietance map");
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
ASSERT_TRUE(status.ok());
ASSERT_EQ(src_result.size(), nq);
rc.RecordSection("cluster result");
CheckCluster(target_ids, target_distence, src_result, nq, topk);
rc.RecordSection("check result");
};
DoCluster(10000, 1000);
DoCluster(333, 999);
DoCluster(1, 1000);
DoCluster(1, 1);
DoCluster(7, 0);
DoCluster(9999, 1);
DoCluster(10001, 1);
DoCluster(58273, 1234);
}
TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
std::vector<long> target_ids;
std::vector<float> target_distence;
engine::SearchContext::ResultSet src_result;
std::vector<long> insufficient_ids;
std::vector<float> insufficient_distence;
engine::SearchContext::ResultSet insufficient_result;
auto DoTopk = [&](int64_t nq, int64_t topk,int64_t insufficient_topk, bool ascending) {
src_result.clear();
insufficient_result.clear();
server::TimeRecorder rc("DoCluster");
BuildResult(nq, topk, ascending, target_ids, target_distence);
auto status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
rc.RecordSection("cluster result");
BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
status = engine::SearchTask::ClusterResult(target_ids, target_distence, nq, insufficient_topk, insufficient_result);
rc.RecordSection("cluster result");
engine::SearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
ASSERT_TRUE(status.ok());
rc.RecordSection("topk");
CheckTopkResult(src_result, ascending, nq, topk);
rc.RecordSection("check result");
};
DoTopk(5, 10, 4, false);
DoTopk(20005, 998, 123, true);
DoTopk(9987, 12, 10, false);
DoTopk(77777, 1000, 1, false);
DoTopk(5432, 8899, 8899, true);
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册