提交 6b612ffc 编写于 作者: Y Yang Xuan

Merge remote-tracking branch 'mega/branch-0.3.0' into branch-0.3.0


Former-commit-id: 925bf53361e59ad626892bfce3299d49f5a2560a
......@@ -12,6 +12,8 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-57 - Implement index load/search pipeline
- MS-56 - Add version information when server is started
- MS-64 - Different table can have different index type
- MS-52 - Return search score
## Task
......
......@@ -85,7 +85,6 @@ set(third_party_libs
prometheus-cpp-pull
prometheus-cpp-core
civetweb
rocksdb
boost_system_static
boost_filesystem_static
boost_serialization_static
......@@ -165,7 +164,6 @@ endif ()
set(server_libs
vecwise_engine
librocksdb.a
libthrift.a
pthread
libyaml-cpp.a
......
......@@ -16,15 +16,7 @@ namespace engine {
DB::~DB() {}
void DB::Open(const Options& options, DB** dbptr) {
*dbptr = nullptr;
#ifdef GPU_VERSION
std::string default_index_type{"Faiss,IVF"};
#else
std::string default_index_type{"Faiss,IDMap"};
#endif
*dbptr = DBFactory::Build(options, default_index_type);
*dbptr = DBFactory::Build(options);
return;
}
......
......@@ -8,12 +8,12 @@
#include "DB.h"
#include "MemManager.h"
#include "Types.h"
#include "Traits.h"
#include <mutex>
#include <condition_variable>
#include <memory>
#include <atomic>
#include <thread>
namespace zilliz {
namespace vecwise {
......@@ -25,11 +25,10 @@ namespace meta {
class Meta;
}
template <typename EngineT>
class DBImpl : public DB {
public:
using MetaPtr = meta::Meta::Ptr;
using MemManagerPtr = typename MemManager<EngineT>::Ptr;
using MemManagerPtr = typename MemManager::Ptr;
DBImpl(const Options& options);
......@@ -100,5 +99,3 @@ private:
} // namespace engine
} // namespace vecwise
} // namespace zilliz
#include "DBImpl.inl"
此差异已折叠。
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "EngineFactory.h"
#include "FaissExecutionEngine.h"
#include "Log.h"
namespace zilliz {
namespace vecwise {
namespace engine {
ExecutionEnginePtr
EngineFactory::Build(uint16_t dimension,
const std::string& location,
EngineType type) {
switch(type) {
case EngineType::FAISS_IDMAP:
return ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IDMap", "IDMap,Flat"));
case EngineType::FAISS_IVFFLAT:
return ExecutionEnginePtr(new FaissExecutionEngine(dimension, location, "IVF", "IDMap,Flat"));
default:
ENGINE_LOG_ERROR << "Unsupportted engine type";
return nullptr;
}
}
}
}
}
\ No newline at end of file
......@@ -5,22 +5,20 @@
******************************************************************************/
#pragma once
#include "Status.h"
#include "ExecutionEngine.h"
namespace zilliz {
namespace vecwise {
namespace engine {
struct IVFIndexTrait {
static const char* BuildIndexType;
static const char* RawIndexType;
class EngineFactory {
public:
static ExecutionEnginePtr Build(uint16_t dimension,
const std::string& location,
EngineType type);
};
struct IDMapIndexTrait {
static const char* BuildIndexType;
static const char* RawIndexType;
};
} // namespace engine
} // namespace vecwise
} // namespace zilliz
}
}
}
......@@ -11,8 +11,7 @@ namespace zilliz {
namespace vecwise {
namespace engine {
template<typename Derived>
Status ExecutionEngine<Derived>::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
Status ExecutionEngine::AddWithIds(const std::vector<float>& vectors, const std::vector<long>& vector_ids) {
long n1 = (long)vectors.size();
long n2 = (long)vector_ids.size();
if (n1 != n2) {
......@@ -22,60 +21,6 @@ Status ExecutionEngine<Derived>::AddWithIds(const std::vector<float>& vectors, c
return AddWithIds(n1, vectors.data(), vector_ids.data());
}
template<typename Derived>
Status ExecutionEngine<Derived>::AddWithIds(long n, const float *xdata, const long *xids) {
return static_cast<Derived*>(this)->AddWithIds(n, xdata, xids);
}
template<typename Derived>
size_t ExecutionEngine<Derived>::Count() const {
return static_cast<Derived*>(this)->Count();
}
template<typename Derived>
size_t ExecutionEngine<Derived>::Size() const {
return static_cast<Derived*>(this)->Size();
}
template<typename Derived>
size_t ExecutionEngine<Derived>::PhysicalSize() const {
return static_cast<Derived*>(this)->PhysicalSize();
}
template<typename Derived>
Status ExecutionEngine<Derived>::Serialize() {
return static_cast<Derived*>(this)->Serialize();
}
template<typename Derived>
Status ExecutionEngine<Derived>::Load() {
return static_cast<Derived*>(this)->Load();
}
template<typename Derived>
Status ExecutionEngine<Derived>::Merge(const std::string& location) {
return static_cast<Derived*>(this)->Merge(location);
}
template<typename Derived>
Status ExecutionEngine<Derived>::Search(long n,
const float *data,
long k,
float *distances,
long *labels) const {
return static_cast<Derived*>(this)->Search(n, data, k, distances, labels);
}
template<typename Derived>
Status ExecutionEngine<Derived>::Cache() {
return static_cast<Derived*>(this)->Cache();
}
template<typename Derived>
std::shared_ptr<Derived> ExecutionEngine<Derived>::BuildIndex(const std::string& location) {
return static_cast<Derived*>(this)->BuildIndex(location);
}
} // namespace engine
} // namespace vecwise
......
......@@ -14,38 +14,47 @@ namespace zilliz {
namespace vecwise {
namespace engine {
template <typename Derived>
enum class EngineType {
INVALID = 0,
FAISS_IDMAP = 1,
FAISS_IVFFLAT,
};
class ExecutionEngine {
public:
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
virtual Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
Status AddWithIds(long n, const float *xdata, const long *xids);
virtual Status AddWithIds(long n, const float *xdata, const long *xids) = 0;
size_t Count() const;
virtual size_t Count() const = 0;
size_t Size() const;
virtual size_t Size() const = 0;
size_t PhysicalSize() const;
virtual size_t Dimension() const = 0;
Status Serialize();
virtual size_t PhysicalSize() const = 0;
Status Load();
virtual Status Serialize() = 0;
Status Merge(const std::string& location);
virtual Status Load() = 0;
Status Search(long n,
virtual Status Merge(const std::string& location) = 0;
virtual Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const;
long *labels) const = 0;
std::shared_ptr<Derived> BuildIndex(const std::string&);
virtual std::shared_ptr<ExecutionEngine> BuildIndex(const std::string&) = 0;
Status Cache();
virtual Status Cache() = 0;
};
using ExecutionEnginePtr = std::shared_ptr<ExecutionEngine>;
} // namespace engine
} // namespace vecwise
......
......@@ -5,8 +5,6 @@
////////////////////////////////////////////////////////////////////////////////
#include "Factories.h"
#include "DBImpl.h"
#include "FaissExecutionEngine.h"
#include "Traits.h"
#include <stdlib.h>
#include <time.h>
......@@ -45,28 +43,14 @@ std::shared_ptr<meta::DBMetaImpl> DBMetaImplFactory::Build() {
return std::shared_ptr<meta::DBMetaImpl>(new meta::DBMetaImpl(options));
}
std::shared_ptr<DB> DBFactory::Build(const std::string& db_type) {
std::shared_ptr<DB> DBFactory::Build() {
auto options = OptionsFactory::Build();
auto db = DBFactory::Build(options, db_type);
auto db = DBFactory::Build(options);
return std::shared_ptr<DB>(db);
}
DB* DBFactory::Build(const Options& options, const std::string& db_type) {
std::stringstream ss(db_type);
std::string token;
std::vector<std::string> tokens;
while (std::getline(ss, token, ',')) {
tokens.push_back(token);
}
assert(tokens.size()==2);
assert(tokens[0]=="Faiss");
if (tokens[1] == "IVF") {
return new DBImpl<FaissExecutionEngine<IVFIndexTrait>>(options);
} else if (tokens[1] == "IDMap") {
return new DBImpl<FaissExecutionEngine<IDMapIndexTrait>>(options);
}
return nullptr;
DB* DBFactory::Build(const Options& options) {
return new DBImpl(options);
}
} // namespace engine
......
......@@ -8,6 +8,7 @@
#include "DB.h"
#include "DBMetaImpl.h"
#include "Options.h"
#include "ExecutionEngine.h"
#include <string>
#include <memory>
......@@ -29,8 +30,8 @@ struct DBMetaImplFactory {
};
struct DBFactory {
static std::shared_ptr<DB> Build(const std::string& db_type = "Faiss,IVF");
static DB* Build(const Options&, const std::string& db_type = "Faiss,IVF");
static std::shared_ptr<DB> Build();
static DB* Build(const Options&);
};
} // namespace engine
......
......@@ -3,8 +3,6 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "FaissExecutionEngine.h"
#include <easylogging++.h>
......@@ -23,47 +21,53 @@ namespace vecwise {
namespace engine {
template<class IndexTrait>
FaissExecutionEngine<IndexTrait>::FaissExecutionEngine(uint16_t dimension, const std::string& location)
: pIndex_(faiss::index_factory(dimension, IndexTrait::RawIndexType)),
location_(location) {
FaissExecutionEngine::FaissExecutionEngine(uint16_t dimension,
const std::string& location,
const std::string& build_index_type,
const std::string& raw_index_type)
: pIndex_(faiss::index_factory(dimension, raw_index_type.c_str())),
location_(location),
build_index_type_(build_index_type),
raw_index_type_(raw_index_type) {
}
template<class IndexTrait>
FaissExecutionEngine<IndexTrait>::FaissExecutionEngine(std::shared_ptr<faiss::Index> index, const std::string& location)
FaissExecutionEngine::FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
const std::string& location,
const std::string& build_index_type,
const std::string& raw_index_type)
: pIndex_(index),
location_(location) {
location_(location),
build_index_type_(build_index_type),
raw_index_type_(raw_index_type) {
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::AddWithIds(long n, const float *xdata, const long *xids) {
Status FaissExecutionEngine::AddWithIds(long n, const float *xdata, const long *xids) {
pIndex_->add_with_ids(n, xdata, xids);
return Status::OK();
}
template<class IndexTrait>
size_t FaissExecutionEngine<IndexTrait>::Count() const {
size_t FaissExecutionEngine::Count() const {
return (size_t)(pIndex_->ntotal);
}
template<class IndexTrait>
size_t FaissExecutionEngine<IndexTrait>::Size() const {
size_t FaissExecutionEngine::Size() const {
return (size_t)(Count() * pIndex_->d)*sizeof(float);
}
template<class IndexTrait>
size_t FaissExecutionEngine<IndexTrait>::PhysicalSize() const {
size_t FaissExecutionEngine::Dimension() const {
return pIndex_->d;
}
size_t FaissExecutionEngine::PhysicalSize() const {
return (size_t)(Count() * pIndex_->d)*sizeof(float);
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::Serialize() {
Status FaissExecutionEngine::Serialize() {
write_index(pIndex_.get(), location_.c_str());
return Status::OK();
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::Load() {
Status FaissExecutionEngine::Load() {
auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
bool to_cache = false;
auto start_time = METRICS_NOW_TIME;
......@@ -90,8 +94,7 @@ Status FaissExecutionEngine<IndexTrait>::Load() {
return Status::OK();
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::Merge(const std::string& location) {
Status FaissExecutionEngine::Merge(const std::string& location) {
if (location == location_) {
return Status::Error("Cannot Merge Self");
}
......@@ -105,12 +108,11 @@ Status FaissExecutionEngine<IndexTrait>::Merge(const std::string& location) {
return Status::OK();
}
template<class IndexTrait>
typename FaissExecutionEngine<IndexTrait>::Ptr
FaissExecutionEngine<IndexTrait>::BuildIndex(const std::string& location) {
ExecutionEnginePtr
FaissExecutionEngine::BuildIndex(const std::string& location) {
auto opd = std::make_shared<Operand>();
opd->d = pIndex_->d;
opd->index_type = IndexTrait::BuildIndexType;
opd->index_type = build_index_type_;
IndexBuilderPtr pBuilder = GetIndexBuilder(opd);
auto from_index = dynamic_cast<faiss::IndexIDMap*>(pIndex_.get());
......@@ -119,13 +121,12 @@ FaissExecutionEngine<IndexTrait>::BuildIndex(const std::string& location) {
dynamic_cast<faiss::IndexFlat*>(from_index->index)->xb.data(),
from_index->id_map.data());
Ptr new_ee(new FaissExecutionEngine<IndexTrait>(index->data(), location));
ExecutionEnginePtr new_ee(new FaissExecutionEngine(index->data(), location, build_index_type_, raw_index_type_));
new_ee->Serialize();
return new_ee;
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::Search(long n,
Status FaissExecutionEngine::Search(long n,
const float *data,
long k,
float *distances,
......@@ -135,8 +136,7 @@ Status FaissExecutionEngine<IndexTrait>::Search(long n,
return Status::OK();
}
template<class IndexTrait>
Status FaissExecutionEngine<IndexTrait>::Cache() {
Status FaissExecutionEngine::Cache() {
zilliz::vecwise::cache::CpuCacheMgr::GetInstance(
)->InsertItem(location_, std::make_shared<Index>(pIndex_));
......
......@@ -19,50 +19,54 @@ namespace vecwise {
namespace engine {
template<class IndexTrait>
class FaissExecutionEngine : public ExecutionEngine<FaissExecutionEngine<IndexTrait>> {
class FaissExecutionEngine : public ExecutionEngine {
public:
using Ptr = std::shared_ptr<FaissExecutionEngine<IndexTrait>>;
FaissExecutionEngine(uint16_t dimension, const std::string& location);
FaissExecutionEngine(std::shared_ptr<faiss::Index> index, const std::string& location);
FaissExecutionEngine(uint16_t dimension,
const std::string& location,
const std::string& build_index_type,
const std::string& raw_index_type);
Status AddWithIds(const std::vector<float>& vectors,
const std::vector<long>& vector_ids);
FaissExecutionEngine(std::shared_ptr<faiss::Index> index,
const std::string& location,
const std::string& build_index_type,
const std::string& raw_index_type);
Status AddWithIds(long n, const float *xdata, const long *xids);
Status AddWithIds(long n, const float *xdata, const long *xids) override;
size_t Count() const;
size_t Count() const override;
size_t Size() const;
size_t Size() const override;
size_t PhysicalSize() const;
size_t Dimension() const override;
Status Serialize();
size_t PhysicalSize() const override;
Status Load();
Status Serialize() override;
Status Merge(const std::string& location);
Status Load() override;
Status Merge(const std::string& location) override;
Status Search(long n,
const float *data,
long k,
float *distances,
long *labels) const;
long *labels) const override;
Ptr BuildIndex(const std::string&);
ExecutionEnginePtr BuildIndex(const std::string&) override;
Status Cache();
Status Cache() override;
protected:
std::shared_ptr<faiss::Index> pIndex_;
std::string location_;
std::string build_index_type_;
std::string raw_index_type_;
};
} // namespace engine
} // namespace vecwise
} // namespace zilliz
#include "FaissExecutionEngine.inl"
......@@ -3,18 +3,24 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "Traits.h"
#pragma once
#include <easylogging++.h>
namespace zilliz {
namespace vecwise {
namespace engine {
const char* IVFIndexTrait::BuildIndexType = "IVF";
const char* IVFIndexTrait::RawIndexType = "IDMap,Flat";
#define ENGINE_DOMAIN_NAME "[ENGINE] "
#define ENGINE_ERROR_TEXT "ENGINE Error:"
const char* IDMapIndexTrait::BuildIndexType = "IDMap";
const char* IDMapIndexTrait::RawIndexType = "IDMap,Flat";
#define ENGINE_LOG_TRACE LOG(TRACE) << ENGINE_DOMAIN_NAME
#define ENGINE_LOG_DEBUG LOG(DEBUG) << ENGINE_DOMAIN_NAME
#define ENGINE_LOG_INFO LOG(INFO) << ENGINE_DOMAIN_NAME
#define ENGINE_LOG_WARNING LOG(WARNING) << ENGINE_DOMAIN_NAME
#define ENGINE_LOG_ERROR LOG(ERROR) << ENGINE_DOMAIN_NAME
#define ENGINE_LOG_FATAL LOG(FATAL) << ENGINE_DOMAIN_NAME
} // namespace engine
} // namespace vecwise
} // namespace sql
} // namespace zilliz
} // namespace server
......@@ -3,11 +3,10 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MemManager.h"
#include "Meta.h"
#include "MetaConsts.h"
#include "EngineFactory.h"
#include "metrics/Metrics.h"
#include <iostream>
......@@ -19,59 +18,53 @@ namespace zilliz {
namespace vecwise {
namespace engine {
template<typename EngineT>
MemVectors<EngineT>::MemVectors(const std::shared_ptr<meta::Meta>& meta_ptr,
MemVectors::MemVectors(const std::shared_ptr<meta::Meta>& meta_ptr,
const meta::TableFileSchema& schema, const Options& options)
: pMeta_(meta_ptr),
options_(options),
schema_(schema),
pIdGenerator_(new SimpleIDGenerator()),
pEE_(new EngineT(schema_.dimension, schema_.location)) {
pEE_(EngineFactory::Build(schema_.dimension_, schema_.location_, (EngineType)schema_.engine_type_)) {
}
template<typename EngineT>
void MemVectors<EngineT>::Add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
void MemVectors::Add(size_t n_, const float* vectors_, IDNumbers& vector_ids_) {
pIdGenerator_->GetNextIDNumbers(n_, vector_ids_);
pEE_->AddWithIds(n_, vectors_, vector_ids_.data());
}
template<typename EngineT>
size_t MemVectors<EngineT>::Total() const {
size_t MemVectors::Total() const {
return pEE_->Count();
}
template<typename EngineT>
size_t MemVectors<EngineT>::ApproximateSize() const {
size_t MemVectors::ApproximateSize() const {
return pEE_->Size();
}
template<typename EngineT>
Status MemVectors<EngineT>::Serialize(std::string& table_id) {
table_id = schema_.table_id;
Status MemVectors::Serialize(std::string& table_id) {
table_id = schema_.table_id_;
auto size = ApproximateSize();
auto start_time = METRICS_NOW_TIME;
pEE_->Serialize();
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
schema_.size = size;
schema_.size_ = size;
server::Metrics::GetInstance().DiskStoreIOSpeedGaugeSet(size/total_time);
schema_.file_type = (size >= options_.index_trigger_size) ?
schema_.file_type_ = (size >= options_.index_trigger_size) ?
meta::TableFileSchema::TO_INDEX : meta::TableFileSchema::RAW;
auto status = pMeta_->UpdateTableFile(schema_);
LOG(DEBUG) << "New " << ((schema_.file_type == meta::TableFileSchema::RAW) ? "raw" : "to_index")
<< " file " << schema_.file_id << " of size " << pEE_->Size() / meta::M << " M";
LOG(DEBUG) << "New " << ((schema_.file_type_ == meta::TableFileSchema::RAW) ? "raw" : "to_index")
<< " file " << schema_.file_id_ << " of size " << pEE_->Size() / meta::M << " M";
pEE_->Cache();
return status;
}
template<typename EngineT>
MemVectors<EngineT>::~MemVectors() {
MemVectors::~MemVectors() {
if (pIdGenerator_ != nullptr) {
delete pIdGenerator_;
pIdGenerator_ = nullptr;
......@@ -81,9 +74,7 @@ MemVectors<EngineT>::~MemVectors() {
/*
* MemManager
*/
template<typename EngineT>
typename MemManager<EngineT>::MemVectorsPtr MemManager<EngineT>::GetMemByTable(
MemManager::MemVectorsPtr MemManager::GetMemByTable(
const std::string& table_id) {
auto memIt = memMap_.find(table_id);
if (memIt != memMap_.end()) {
......@@ -91,18 +82,17 @@ typename MemManager<EngineT>::MemVectorsPtr MemManager<EngineT>::GetMemByTable(
}
meta::TableFileSchema table_file;
table_file.table_id = table_id;
table_file.table_id_ = table_id;
auto status = pMeta_->CreateTableFile(table_file);
if (!status.ok()) {
return nullptr;
}
memMap_[table_id] = MemVectorsPtr(new MemVectors<EngineT>(pMeta_, table_file, options_));
memMap_[table_id] = MemVectorsPtr(new MemVectors(pMeta_, table_file, options_));
return memMap_[table_id];
}
template<typename EngineT>
Status MemManager<EngineT>::InsertVectors(const std::string& table_id_,
Status MemManager::InsertVectors(const std::string& table_id_,
size_t n_,
const float* vectors_,
IDNumbers& vector_ids_) {
......@@ -110,8 +100,7 @@ Status MemManager<EngineT>::InsertVectors(const std::string& table_id_,
return InsertVectorsNoLock(table_id_, n_, vectors_, vector_ids_);
}
template<typename EngineT>
Status MemManager<EngineT>::InsertVectorsNoLock(const std::string& table_id,
Status MemManager::InsertVectorsNoLock(const std::string& table_id,
size_t n,
const float* vectors,
IDNumbers& vector_ids) {
......@@ -124,8 +113,7 @@ Status MemManager<EngineT>::InsertVectorsNoLock(const std::string& table_id,
return Status::OK();
}
template<typename EngineT>
Status MemManager<EngineT>::ToImmutable() {
Status MemManager::ToImmutable() {
std::unique_lock<std::mutex> lock(mutex_);
for (auto& kv: memMap_) {
immMems_.push_back(kv.second);
......@@ -135,8 +123,7 @@ Status MemManager<EngineT>::ToImmutable() {
return Status::OK();
}
template<typename EngineT>
Status MemManager<EngineT>::Serialize(std::vector<std::string>& table_ids) {
Status MemManager::Serialize(std::vector<std::string>& table_ids) {
ToImmutable();
std::unique_lock<std::mutex> lock(serialization_mtx_);
std::string table_id;
......
......@@ -5,6 +5,7 @@
******************************************************************************/
#pragma once
#include "ExecutionEngine.h"
#include "IDGenerator.h"
#include "Status.h"
#include "Meta.h"
......@@ -23,12 +24,10 @@ namespace meta {
class Meta;
}
template <typename EngineT>
class MemVectors {
public:
using EnginePtr = typename EngineT::Ptr;
using MetaPtr = meta::Meta::Ptr;
using Ptr = std::shared_ptr<MemVectors<EngineT>>;
using Ptr = std::shared_ptr<MemVectors>;
explicit MemVectors(const std::shared_ptr<meta::Meta>&,
const meta::TableFileSchema&, const Options&);
......@@ -43,7 +42,7 @@ public:
~MemVectors();
const std::string& Location() const { return schema_.location; }
const std::string& Location() const { return schema_.location_; }
private:
MemVectors() = delete;
......@@ -54,18 +53,17 @@ private:
Options options_;
meta::TableFileSchema schema_;
IDGenerator* pIdGenerator_;
EnginePtr pEE_;
ExecutionEnginePtr pEE_;
}; // MemVectors
template<typename EngineT>
class MemManager {
public:
using MetaPtr = meta::Meta::Ptr;
using MemVectorsPtr = typename MemVectors<EngineT>::Ptr;
using Ptr = std::shared_ptr<MemManager<EngineT>>;
using MemVectorsPtr = typename MemVectors::Ptr;
using Ptr = std::shared_ptr<MemManager>;
MemManager(const std::shared_ptr<meta::Meta>& meta, const Options& options)
: pMeta_(meta), options_(options) {}
......@@ -96,4 +94,3 @@ private:
} // namespace engine
} // namespace vecwise
} // namespace zilliz
#include "MemManager.inl"
......@@ -5,6 +5,8 @@
******************************************************************************/
#pragma once
#include "ExecutionEngine.h"
#include <vector>
#include <map>
#include <string>
......@@ -19,12 +21,14 @@ const DateT EmptyDate = -1;
typedef std::vector<DateT> DatesT;
struct TableSchema {
size_t id;
std::string table_id;
size_t files_cnt = 0;
uint16_t dimension;
std::string location;
long created_on;
size_t id_;
std::string table_id_;
size_t files_cnt_ = 0;
uint16_t dimension_;
std::string location_;
long created_on_;
int engine_type_ = (int)EngineType::FAISS_IDMAP;
bool store_raw_data_ = false;
}; // TableSchema
struct TableFileSchema {
......@@ -36,16 +40,17 @@ struct TableFileSchema {
TO_DELETE,
} FILE_TYPE;
size_t id;
std::string table_id;
std::string file_id;
int file_type = NEW;
size_t size;
DateT date = EmptyDate;
uint16_t dimension;
std::string location;
long updated_time;
long created_on;
size_t id_;
std::string table_id_;
int engine_type_ = (int)EngineType::FAISS_IDMAP;
std::string file_id_;
int file_type_ = NEW;
size_t size_;
DateT date_ = EmptyDate;
uint16_t dimension_;
std::string location_;
long updated_time_;
long created_on_;
}; // TableFileSchema
typedef std::vector<TableFileSchema> TableFilesSchema;
......
......@@ -15,7 +15,7 @@ typedef long IDNumber;
typedef IDNumber* IDNumberPtr;
typedef std::vector<IDNumber> IDNumbers;
typedef std::vector<IDNumber> QueryResult;
typedef std::vector<std::pair<IDNumber, double>> QueryResult;
typedef std::vector<QueryResult> QueryResults;
......
......@@ -24,21 +24,21 @@ public:
SearchContext::Id2IndexMap index_files = search_context->GetIndexMap();
//some index loader alread exists
for(auto& loader : loader_list) {
if(index_files.find(loader->file_->id) != index_files.end()){
if(index_files.find(loader->file_->id_) != index_files.end()){
SERVER_LOG_INFO << "Append SearchContext to exist IndexLoaderContext";
index_files.erase(loader->file_->id);
index_files.erase(loader->file_->id_);
loader->search_contexts_.push_back(search_context);
}
}
//index_files still contains some index files, create new loader
for(auto& pair : index_files) {
SERVER_LOG_INFO << "Create new IndexLoaderContext for: " << pair.second->location;
SERVER_LOG_INFO << "Create new IndexLoaderContext for: " << pair.second->location_;
IndexLoaderContextPtr new_loader = std::make_shared<IndexLoaderContext>();
new_loader->search_contexts_.push_back(search_context);
new_loader->file_ = pair.second;
auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(pair.second->location);
auto index = zilliz::vecwise::cache::CpuCacheMgr::GetInstance()->GetIndex(pair.second->location_);
if(index != nullptr) {
//if the index file has been in memory, increase its priority
loader_list.push_front(new_loader);
......
......@@ -26,13 +26,13 @@ SearchContext::SearchContext(uint64_t topk, uint64_t nq, const float* vectors)
bool
SearchContext::AddIndexFile(TableFileSchemaPtr& index_file) {
std::unique_lock <std::mutex> lock(mtx_);
if(index_file == nullptr || map_index_files_.find(index_file->id) != map_index_files_.end()) {
if(index_file == nullptr || map_index_files_.find(index_file->id_) != map_index_files_.end()) {
return false;
}
SERVER_LOG_INFO << "SearchContext " << identity_ << " add index file: " << index_file->id;
SERVER_LOG_INFO << "SearchContext " << identity_ << " add index file: " << index_file->id_;
map_index_files_[index_file->id] = index_file;
map_index_files_[index_file->id_] = index_file;
return true;
}
......
......@@ -31,8 +31,8 @@ public:
using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
const Id2IndexMap& GetIndexMap() const { return map_index_files_; }
using Score2IdMap = std::map<float, int64_t>;
using ResultSet = std::vector<Score2IdMap>;
using Id2ScoreMap = std::vector<std::pair<int64_t, double>>;
using ResultSet = std::vector<Id2ScoreMap>;
const ResultSet& GetResult() const { return result_; }
ResultSet& GetResult() { return result_; }
......
......@@ -10,11 +10,50 @@
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "metrics/Metrics.h"
#include "db/EngineFactory.h"
namespace zilliz {
namespace vecwise {
namespace engine {
namespace {
void CollectFileMetrics(int file_type, size_t file_size) {
switch(file_type) {
case meta::TableFileSchema::RAW:
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
break;
}
default: {
server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
break;
}
}
}
void CollectDurationMetrics(int index_type, double total_time) {
switch(index_type) {
case meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
default: {
server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
break;
}
}
}
}
SearchScheduler::SearchScheduler()
: thread_pool_(2),
stopped_(true) {
......@@ -75,45 +114,27 @@ SearchScheduler::IndexLoadWorker() {
break;//exit
}
SERVER_LOG_INFO << "Loading index(" << context->file_->id << ") from location: " << context->file_->location;
SERVER_LOG_INFO << "Loading index(" << context->file_->id_ << ") from location: " << context->file_->location_;
server::TimeRecorder rc("Load index");
//load index
IndexEnginePtr index_ptr = std::make_shared<IndexClass>(context->file_->dimension, context->file_->location);
//step 1: load index
ExecutionEnginePtr index_ptr = EngineFactory::Build(context->file_->dimension_,
context->file_->location_,
(EngineType)context->file_->engine_type_);
index_ptr->Load();
rc.Record("load index file to memory");
size_t file_size = index_ptr->PhysicalSize();
LOG(DEBUG) << "Index file type " << context->file_->file_type << " Of Size: "
LOG(DEBUG) << "Index file type " << context->file_->file_type_ << " Of Size: "
<< file_size/(1024*1024) << " M";
//metric
switch(context->file_->file_type) {
case meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
break;
}
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
break;
}
default: {
server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
break;
}
}
CollectFileMetrics(context->file_->file_type_, file_size);
//put search task to another queue
SearchTaskPtr task_ptr = std::make_shared<SearchTaskClass>();
task_ptr->index_id_ = context->file_->id;
task_ptr->index_type_ = context->file_->file_type;
//step 2: put search task into another queue
SearchTaskPtr task_ptr = std::make_shared<SearchTask>();
task_ptr->index_id_ = context->file_->id_;
task_ptr->index_type_ = context->file_->file_type_;
task_ptr->index_engine_ = index_ptr;
task_ptr->search_contexts_.swap(context->search_contexts_);
search_queue.Put(task_ptr);
......@@ -140,20 +161,7 @@ SearchScheduler::SearchWorker() {
task_ptr->DoSearch();
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
switch(task_ptr->index_type_) {
case meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
case meta::TableFileSchema::TO_INDEX: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
break;
}
default: {
server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
break;
}
}
CollectDurationMetrics(task_ptr->index_type_, total_time);
}
return true;
......
......@@ -3,8 +3,6 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "SearchTaskQueue.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
......@@ -21,12 +19,29 @@ void ClusterResult(const std::vector<long> &output_ids,
SearchContext::ResultSet &result_set) {
result_set.clear();
for (auto i = 0; i < nq; i++) {
SearchContext::Score2IdMap score2id;
SearchContext::Id2ScoreMap id_score;
for (auto k = 0; k < topk; k++) {
uint64_t index = i * nq + k;
score2id.insert(std::make_pair(output_distence[index], output_ids[index]));
id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
}
result_set.emplace_back(score2id);
result_set.emplace_back(id_score);
}
}
void MergeResult(SearchContext::Id2ScoreMap &score_src,
SearchContext::Id2ScoreMap &score_target,
uint64_t topk) {
for (auto& pair_src : score_src) {
for (auto iter = score_target.begin(); iter != score_target.end(); ++iter) {
if(pair_src.second > iter->second) {
score_target.insert(iter, pair_src);
}
}
}
//remove unused items
while (score_target.size() > topk) {
score_target.pop_back();
}
}
......@@ -44,18 +59,39 @@ void TopkResult(SearchContext::ResultSet &result_src,
}
for (size_t i = 0; i < result_src.size(); i++) {
SearchContext::Score2IdMap &score2id_src = result_src[i];
SearchContext::Score2IdMap &score2id_target = result_target[i];
for (auto iter = score2id_src.begin(); iter != score2id_src.end(); ++iter) {
score2id_target.insert(std::make_pair(iter->first, iter->second));
SearchContext::Id2ScoreMap &score_src = result_src[i];
SearchContext::Id2ScoreMap &score_target = result_target[i];
MergeResult(score_src, score_target, topk);
}
}
void CalcScore(uint64_t vector_count,
const float *vectors_data,
uint64_t dimension,
const SearchContext::ResultSet &result_src,
SearchContext::ResultSet &result_target) {
result_target.clear();
if(result_src.empty()){
return;
}
int vec_index = 0;
for(auto& result : result_src) {
const float * vec_data = vectors_data + vec_index*dimension;
double vec_len = 0;
for(uint64_t i = 0; i < dimension; i++) {
vec_len += vec_data[i]*vec_data[i];
}
vec_index++;
//remove unused items
while (score2id_target.size() > topk) {
score2id_target.erase(score2id_target.rbegin()->first);
SearchContext::Id2ScoreMap score_array;
for(auto& pair : result) {
score_array.push_back(std::make_pair(pair.first, (1 - pair.second/vec_len)*100.0));
}
result_target.emplace_back(score_array);
}
}
}
......@@ -70,8 +106,7 @@ SearchTaskQueue::GetInstance() {
return instance;
}
template<typename trait>
bool SearchTask<trait>::DoSearch() {
bool SearchTask::DoSearch() {
if(index_engine_ == nullptr) {
return false;
}
......@@ -81,10 +116,12 @@ bool SearchTask<trait>::DoSearch() {
std::vector<long> output_ids;
std::vector<float> output_distence;
for(auto& context : search_contexts_) {
//step 1: allocate memory
auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
output_ids.resize(inner_k*context->nq());
output_distence.resize(inner_k*context->nq());
//step 2: search
try {
index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
output_ids.data());
......@@ -96,11 +133,21 @@ bool SearchTask<trait>::DoSearch() {
rc.Record("do search");
//step 3: cluster result
SearchContext::ResultSet result_set;
ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set);
rc.Record("cluster result");
//step 4: pick up topk result
TopkResult(result_set, inner_k, context->GetResult());
rc.Record("reduce topk");
//step 5: calculate score between 0 ~ 100
CalcScore(context->nq(), context->vectors(), index_engine_->Dimension(), context->GetResult(), result_set);
context->GetResult().swap(result_set);
rc.Record("calculate score");
//step 6: notify to send result to client
context->IndexSearchDone(index_id_);
}
......
......@@ -7,8 +7,7 @@
#include "SearchContext.h"
#include "utils/BlockingQueue.h"
#include "../FaissExecutionEngine.h"
#include "../Traits.h"
#include "db/ExecutionEngine.h"
#include <memory>
......@@ -16,16 +15,6 @@ namespace zilliz {
namespace vecwise {
namespace engine {
#ifdef GPU_VERSION
using IndexTraitClass = IVFIndexTrait;
#else
using IndexTraitClass = IDMapIndexTrait;
#endif
using IndexClass = FaissExecutionEngine<IndexTraitClass>;
using IndexEnginePtr = std::shared_ptr<IndexClass>;
template <typename trait>
class SearchTask {
public:
bool DoSearch();
......@@ -33,12 +22,11 @@ public:
public:
size_t index_id_ = 0;
int index_type_ = 0; //for metrics
IndexEnginePtr index_engine_;
ExecutionEnginePtr index_engine_;
std::vector<SearchContextPtr> search_contexts_;
};
using SearchTaskClass = SearchTask<IndexTraitClass>;
using SearchTaskPtr = std::shared_ptr<SearchTaskClass>;
using SearchTaskPtr = std::shared_ptr<SearchTask>;
class SearchTaskQueue : public server::BlockingQueue<SearchTaskPtr> {
private:
......@@ -58,6 +46,4 @@ private:
}
}
}
#include "SearchTaskQueue.inl"
\ No newline at end of file
}
\ No newline at end of file
......@@ -16,9 +16,6 @@ namespace {
std::string GetTableName();
static const std::string TABLE_NAME = GetTableName();
static const std::string VECTOR_COLUMN_NAME = "face_vector";
static const std::string ID_COLUMN_NAME = "aid";
static const std::string CITY_COLUMN_NAME = "city";
static constexpr int64_t TABLE_DIMENSION = 512;
static constexpr int64_t TOTAL_ROW_COUNT = 100000;
static constexpr int64_t TOP_K = 10;
......@@ -29,9 +26,9 @@ namespace {
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
BLOCK_SPLITER
std::cout << "Table name: " << tb_schema.table_name << std::endl;
std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl;
std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl;
std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl;
std::cout << "Table index type: " << (int)tb_schema.index_type << std::endl;
std::cout << "Table dimension: " << tb_schema.dimension << std::endl;
std::cout << "Table store raw data: " << tb_schema.store_raw_vector << std::endl;
BLOCK_SPLITER
}
......@@ -58,9 +55,6 @@ namespace {
<< " search result:" << std::endl;
for(auto& item : result.query_result_arrays) {
std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score);
for(auto& attribute : item.column_map) {
std::cout << "\t" << attribute.first << ":" << attribute.second;
}
std::cout << std::endl;
}
}
......@@ -88,73 +82,30 @@ namespace {
TableSchema BuildTableSchema() {
TableSchema tb_schema;
VectorColumn col1;
col1.name = VECTOR_COLUMN_NAME;
col1.dimension = TABLE_DIMENSION;
col1.store_raw_vector = true;
tb_schema.vector_column_array.emplace_back(col1);
Column col2 = {ColumnType::int8, ID_COLUMN_NAME};
tb_schema.attribute_column_array.emplace_back(col2);
Column col3 = {ColumnType::int16, CITY_COLUMN_NAME};
tb_schema.attribute_column_array.emplace_back(col3);
tb_schema.table_name = TABLE_NAME;
tb_schema.index_type = IndexType::gpu_ivfflat;
tb_schema.dimension = TABLE_DIMENSION;
tb_schema.store_raw_vector = true;
return tb_schema;
}
void BuildVectors(int64_t from, int64_t to,
std::vector<RowRecord>* vector_record_array,
std::vector<QueryRecord>* query_record_array) {
std::vector<RowRecord>& vector_record_array) {
if(to <= from){
return;
}
if(vector_record_array) {
vector_record_array->clear();
}
if(query_record_array) {
query_record_array->clear();
}
static const std::map<int64_t , std::string> CITY_MAP = {
{0, "Beijing"},
{1, "Shanhai"},
{2, "Hangzhou"},
{3, "Guangzhou"},
{4, "Shenzheng"},
{5, "Wuhan"},
{6, "Chengdu"},
{7, "Chongqin"},
{8, "Tianjing"},
{9, "Hongkong"},
};
vector_record_array.clear();
for (int64_t k = from; k < to; k++) {
std::vector<float> f_p;
f_p.resize(TABLE_DIMENSION);
RowRecord record;
record.data.resize(TABLE_DIMENSION);
for(int64_t i = 0; i < TABLE_DIMENSION; i++) {
f_p[i] = (float)(i + k);
}
if(vector_record_array) {
RowRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.attribute_map[ID_COLUMN_NAME] = std::to_string(k);
record.attribute_map[CITY_COLUMN_NAME] = CITY_MAP.at(k%CITY_MAP.size());
vector_record_array->emplace_back(record);
record.data[i] = (float)(i + k);
}
if(query_record_array) {
QueryRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.selected_column_array.push_back(ID_COLUMN_NAME);
record.selected_column_array.push_back(CITY_COLUMN_NAME);
query_record_array->emplace_back(record);
}
vector_record_array.emplace_back(record);
}
}
}
......@@ -205,7 +156,7 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{//add vectors
std::vector<RowRecord> record_array;
BuildVectors(0, TOTAL_ROW_COUNT, &record_array, nullptr);
BuildVectors(0, TOTAL_ROW_COUNT, record_array);
std::vector<int64_t> record_ids;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "AddVector function call status: " << stat.ToString() << std::endl;
......@@ -215,11 +166,12 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{//search vectors
std::cout << "Waiting data persist. Sleep 10 seconds ..." << std::endl;
sleep(10);
std::vector<QueryRecord> record_array;
BuildVectors(SEARCH_TARGET, SEARCH_TARGET + 10, nullptr, &record_array);
std::vector<RowRecord> record_array;
BuildVectors(SEARCH_TARGET, SEARCH_TARGET + 10, record_array);
std::vector<Range> query_range_array;
std::vector<TopKQueryResult> topk_query_result_array;
Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, TOP_K);
Status stat = conn->SearchVector(TABLE_NAME, record_array, query_range_array, TOP_K, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.ToString() << std::endl;
PrintSearchResult(topk_query_result_array);
}
......
......@@ -4,7 +4,6 @@
#include <string>
#include <vector>
#include <map>
#include <memory>
/** \brief MegaSearch SDK namespace
......@@ -12,129 +11,70 @@
namespace megasearch {
/**
* @brief Column Type
*/
enum class ColumnType {
invalid,
int8,
int16,
int32,
int64,
float32,
float64,
date,
vector
};
/**
* @brief Index Type
*/
enum class IndexType {
raw,
ivfflat
invalid = 0,
cpu_idmap,
gpu_ivfflat,
};
/**
* @brief Connect API parameter
*/
struct ConnectParam {
std::string ip_address; ///< Server IP address
std::string port; ///< Server PORT
};
/**
* @brief Table column description
*/
struct Column {
ColumnType type = ColumnType::invalid; ///< Column Type: enum ColumnType
std::string name; ///< Column name
};
/**
* @brief Table vector column description
*/
struct VectorColumn : public Column {
VectorColumn() { type = ColumnType::vector; }
int64_t dimension = 0; ///< Vector dimension
IndexType index_type = IndexType::raw; ///< Index type
bool store_raw_vector = false; ///< Is vector self stored in the table
std::string ip_address; ///< Server IP address
std::string port; ///< Server PORT
};
/**
* @brief Table Schema
*/
struct TableSchema {
std::string table_name; ///< Table name
std::vector<VectorColumn> vector_column_array; ///< Vector column description
std::vector<Column> attribute_column_array; ///< Columns description
std::vector<std::string> partition_column_name_array; ///< Partition column name
std::string table_name; ///< Table name
IndexType index_type = IndexType::invalid; ///< Index type
int64_t dimension = 0; ///< Vector dimension, must be a positive value
bool store_raw_vector = false; ///< Is vector raw data stored in the table
};
/**
* @brief Range information
* for DATE partition, the format is like: 'year-month-day'
*/
struct Range {
std::string start_value; ///< Range start
std::string end_value; ///< Range stop
};
/**
* @brief Create table partition parameters
*/
struct CreateTablePartitionParam {
std::string table_name; ///< Table name, vector/float32/float64 type column is not allowed for partition
std::string partition_name; ///< Partition name, created partition name
std::map<std::string, Range> range_map; ///< Column name to PartitionRange map
};
/**
* @brief Delete table partition parameters
*/
struct DeleteTablePartitionParam {
std::string table_name; ///< Table name
std::vector<std::string> partition_name_array; ///< Partition name array
std::string start_value; ///< Range start
std::string end_value; ///< Range stop
};
/**
* @brief Record inserted
*/
struct RowRecord {
std::map<std::string, std::vector<float>> vector_map; ///< Vector columns
std::map<std::string, std::string> attribute_map; ///< Other attribute columns
};
/**
* @brief Query record
*/
struct QueryRecord {
std::map<std::string, std::vector<float>> vector_map; ///< Query vectors
std::vector<std::string> selected_column_array; ///< Output column array
std::map<std::string, std::vector<Range>> partition_filter_column_map; ///< Range used to select partitions
std::vector<float> data; ///< Vector raw data
};
/**
* @brief Query result
*/
struct QueryResult {
int64_t id; ///< Output result
double score; ///< Vector similarity score: 0 ~ 100
std::map<std::string, std::string> column_map; ///< Other column
int64_t id; ///< Output result
double score; ///< Vector similarity score: 0 ~ 100
};
/**
* @brief TopK query result
*/
struct TopKQueryResult {
std::vector<QueryResult> query_result_arrays; ///< TopK query result
std::vector<QueryResult> query_result_arrays; ///< TopK query result
};
/**
* @brief SDK main class
*/
class Connection {
public:
public:
/**
* @brief CreateConnection
......@@ -228,30 +168,6 @@ class Connection {
virtual Status DeleteTable(const std::string &table_name) = 0;
/**
* @brief Create table partition
*
* This method is used to create table partition.
*
* @param param, use to provide partition information to be created.
*
* @return Indicate if table partition is created successfully.
*/
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) = 0;
/**
* @brief Delete table partition
*
* This method is used to delete table partition.
*
* @param param, use to provide partition information to be deleted.
*
* @return Indicate if table partition is delete successfully.
*/
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) = 0;
/**
* @brief Add vector to table
*
......@@ -264,8 +180,8 @@ class Connection {
* @return Indicate if vector array are inserted successfully
*/
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
/**
......@@ -275,15 +191,17 @@ class Connection {
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param topk_query_result_array, result array.
* @param query_range_array, time ranges, if not specified, will search in whole table
* @param topk, how many similarity vectors will be searched.
* @param topk_query_result_array, result array.
*
* @return Indicate if query is successful.
*/
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) = 0;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) = 0;
/**
* @brief Show table description
......@@ -297,6 +215,18 @@ class Connection {
*/
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) = 0;
/**
* @brief Get table row count
*
* This method is used to get table row count.
*
* @param table_name, table's name.
* @param row_count, table total row count.
*
* @return Indicate if this operation is successful.
*/
virtual Status GetTableRowCount(const std::string &table_name, int64_t &row_count) = 0;
/**
* @brief Show all tables in database
*
......
......@@ -12,10 +12,13 @@ namespace megasearch {
*/
enum class StatusCode {
OK = 0,
Invalid = 1,
UnknownError = 2,
NotSupported = 3,
NotConnected = 4
// system error section
UnknownError = 1,
NotSupported,
NotConnected,
// function error section
InvalidAgument = 1000,
};
/**
......@@ -171,7 +174,7 @@ class Status {
*/
template<typename... Args>
static Status Invalid(Args &&... args) {
return Status(StatusCode::Invalid,
return Status(StatusCode::InvalidAgument,
MessageBuilder(std::forward<Args>(args)...));
}
......@@ -221,7 +224,7 @@ class Status {
* @return, if the status indicates invalid.
*
*/
bool IsInvalid() const { return code() == StatusCode::Invalid; }
bool IsInvalid() const { return code() == StatusCode::InvalidAgument; }
/**
* @brief IsUnknownError
......
......@@ -89,27 +89,9 @@ ClientProxy::CreateTable(const TableSchema &param) {
try {
thrift::TableSchema schema;
schema.__set_table_name(param.table_name);
std::vector<thrift::VectorColumn> vector_column_array;
for(auto& column : param.vector_column_array) {
thrift::VectorColumn col;
col.__set_dimension(column.dimension);
col.__set_index_type(ConvertUtil::IndexType2Str(column.index_type));
col.__set_store_raw_vector(column.store_raw_vector);
vector_column_array.emplace_back(col);
}
schema.__set_vector_column_array(vector_column_array);
std::vector<thrift::Column> attribute_column_array;
for(auto& column : param.attribute_column_array) {
thrift::Column col;
col.__set_name(col.name);
col.__set_type(col.type);
attribute_column_array.emplace_back(col);
}
schema.__set_attribute_column_array(attribute_column_array);
schema.__set_partition_column_name_array(param.partition_column_name_array);
schema.__set_index_type((int)param.index_type);
schema.__set_dimension(param.dimension);
schema.__set_store_raw_vector(param.store_raw_vector);
ClientPtr()->interface()->CreateTable(schema);
......@@ -120,55 +102,6 @@ ClientProxy::CreateTable(const TableSchema &param) {
return Status::OK();
}
Status
ClientProxy::CreateTablePartition(const CreateTablePartitionParam &param) {
if(!IsConnected()) {
return Status(StatusCode::NotConnected, "not connected to server");
}
try {
thrift::CreateTablePartitionParam partition_param;
partition_param.__set_table_name(param.table_name);
partition_param.__set_partition_name(param.partition_name);
std::map<std::string, thrift::Range> range_map;
for(auto& pair : param.range_map) {
thrift::Range range;
range.__set_start_value(pair.second.start_value);
range.__set_end_value(pair.second.end_value);
range_map.insert(std::make_pair(pair.first, range));
}
partition_param.__set_range_map(range_map);
ClientPtr()->interface()->CreateTablePartition(partition_param);
} catch ( std::exception& ex) {
return Status(StatusCode::UnknownError, "failed to create table partition: " + std::string(ex.what()));
}
return Status::OK();
}
Status
ClientProxy::DeleteTablePartition(const DeleteTablePartitionParam &param) {
if(!IsConnected()) {
return Status(StatusCode::NotConnected, "not connected to server");
}
try {
thrift::DeleteTablePartitionParam partition_param;
partition_param.__set_table_name(param.table_name);
partition_param.__set_partition_name_array(param.partition_name_array);
ClientPtr()->interface()->DeleteTablePartition(partition_param);
} catch ( std::exception& ex) {
return Status(StatusCode::UnknownError, "failed to delete table partition: " + std::string(ex.what()));
}
return Status::OK();
}
Status
ClientProxy::DeleteTable(const std::string &table_name) {
if(!IsConnected()) {
......@@ -197,17 +130,13 @@ ClientProxy::AddVector(const std::string &table_name,
std::vector<thrift::RowRecord> thrift_records;
for(auto& record : record_array) {
thrift::RowRecord thrift_record;
thrift_record.__set_attribute_map(record.attribute_map);
for(auto& pair : record.vector_map) {
size_t dim = pair.second.size();
std::string& thrift_vector = thrift_record.vector_map[pair.first];
thrift_vector.resize(dim * sizeof(double));
double *dbl = (double *) (const_cast<char *>(thrift_vector.data()));
for (size_t i = 0; i < dim; i++) {
dbl[i] = (double) (pair.second[i]);
}
thrift_record.vector_data.resize(record.data.size() * sizeof(double));
double *dbl = (double *) (const_cast<char *>(thrift_record.vector_data.data()));
for (size_t i = 0; i < record.data.size(); i++) {
dbl[i] = (double) (record.data[i]);
}
thrift_records.emplace_back(thrift_record);
}
ClientPtr()->interface()->AddVector(id_array, table_name, thrift_records);
......@@ -221,33 +150,31 @@ ClientProxy::AddVector(const std::string &table_name,
Status
ClientProxy::SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) {
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) {
if(!IsConnected()) {
return Status(StatusCode::NotConnected, "not connected to server");
}
try {
std::vector<thrift::QueryRecord> thrift_records;
std::vector<thrift::RowRecord> thrift_records;
for(auto& record : query_record_array) {
thrift::QueryRecord thrift_record;
thrift_record.__set_selected_column_array(record.selected_column_array);
for(auto& pair : record.vector_map) {
size_t dim = pair.second.size();
std::string& thrift_vector = thrift_record.vector_map[pair.first];
thrift_vector.resize(dim * sizeof(double));
double *dbl = (double *) (const_cast<char *>(thrift_vector.data()));
for (size_t i = 0; i < dim; i++) {
dbl[i] = (double) (pair.second[i]);
}
thrift::RowRecord thrift_record;
thrift_record.vector_data.resize(record.data.size() * sizeof(double));
double *dbl = (double *) (const_cast<char *>(thrift_record.vector_data.data()));
for (size_t i = 0; i < record.data.size(); i++) {
dbl[i] = (double) (record.data[i]);
}
thrift_records.emplace_back(thrift_record);
}
std::vector<thrift::Range> thrift_ranges;
std::vector<thrift::TopKQueryResult> result_array;
ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, topk);
ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, thrift_ranges, topk);
for(auto& thrift_topk_result : result_array) {
TopKQueryResult result;
......@@ -255,7 +182,6 @@ ClientProxy::SearchVector(const std::string &table_name,
for(auto& thrift_query_result : thrift_topk_result.query_result_arrays) {
QueryResult query_result;
query_result.id = thrift_query_result.id;
query_result.column_map = thrift_query_result.column_map;
query_result.score = thrift_query_result.score;
result.query_result_arrays.emplace_back(query_result);
}
......@@ -281,27 +207,26 @@ ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_sch
ClientPtr()->interface()->DescribeTable(thrift_schema, table_name);
table_schema.table_name = thrift_schema.table_name;
table_schema.partition_column_name_array = thrift_schema.partition_column_name_array;
table_schema.index_type = (IndexType)thrift_schema.index_type;
for(auto& thrift_col : thrift_schema.attribute_column_array) {
Column col;
col.name = col.name;
col.type = col.type;
table_schema.attribute_column_array.emplace_back(col);
}
} catch ( std::exception& ex) {
return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(ex.what()));
}
for(auto& thrift_col : thrift_schema.vector_column_array) {
VectorColumn col;
col.store_raw_vector = thrift_col.store_raw_vector;
col.index_type = ConvertUtil::Str2IndexType(thrift_col.index_type);
col.dimension = thrift_col.dimension;
col.name = thrift_col.base.name;
col.type = (ColumnType)thrift_col.base.type;
table_schema.vector_column_array.emplace_back(col);
}
return Status::OK();
}
Status
ClientProxy::GetTableRowCount(const std::string &table_name, int64_t &row_count) {
if(!IsConnected()) {
return Status(StatusCode::NotConnected, "not connected to server");
}
try {
row_count = ClientPtr()->interface()->GetTableRowCount(table_name);
} catch ( std::exception& ex) {
return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(ex.what()));
return Status(StatusCode::UnknownError, "failed to show tables: " + std::string(ex.what()));
}
return Status::OK();
......
......@@ -25,21 +25,20 @@ public:
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status GetTableRowCount(const std::string &table_name, int64_t &row_count) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
virtual std::string ClientVersion() const override;
......
......@@ -58,7 +58,7 @@ ThriftClient::Connect(const std::string& address, int32_t port, const std::strin
protocol_ptr.reset(new TCompactProtocol(transport_ptr));
} else {
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return Status(StatusCode::Invalid, "unsupported protocol");
return Status(StatusCode::InvalidAgument, "unsupported protocol");
}
transport_ptr->open();
......
......@@ -55,16 +55,6 @@ ConnectionImpl::CreateTable(const TableSchema &param) {
return client_proxy_->CreateTable(param);
}
Status
ConnectionImpl::CreateTablePartition(const CreateTablePartitionParam &param) {
return client_proxy_->CreateTablePartition(param);
}
Status
ConnectionImpl::DeleteTablePartition(const DeleteTablePartitionParam &param) {
return client_proxy_->DeleteTablePartition(param);
}
Status
ConnectionImpl::DeleteTable(const std::string &table_name) {
return client_proxy_->DeleteTable(table_name);
......@@ -79,10 +69,11 @@ ConnectionImpl::AddVector(const std::string &table_name,
Status
ConnectionImpl::SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) {
return client_proxy_->SearchVector(table_name, query_record_array, topk_query_result_array, topk);
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) {
return client_proxy_->SearchVector(table_name, query_record_array, query_range_array, topk, topk_query_result_array);
}
Status
......@@ -90,6 +81,11 @@ ConnectionImpl::DescribeTable(const std::string &table_name, TableSchema &table_
return client_proxy_->DescribeTable(table_name, table_schema);
}
Status
ConnectionImpl::GetTableRowCount(const std::string &table_name, int64_t &row_count) {
return client_proxy_->GetTableRowCount(table_name, row_count);
}
Status
ConnectionImpl::ShowTables(std::vector<std::string> &table_array) {
return client_proxy_->ShowTables(table_array);
......
......@@ -27,21 +27,20 @@ public:
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status GetTableRowCount(const std::string &table_name, int64_t &row_count) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
virtual std::string ClientVersion() const override;
......
......@@ -95,7 +95,7 @@ std::string Status::CodeAsString() const {
switch (code()) {
case StatusCode::OK: type = "OK";
break;
case StatusCode::Invalid: type = "Invalid";
case StatusCode::InvalidAgument: type = "Invalid agument";
break;
case StatusCode::UnknownError: type = "Unknown error";
break;
......
......@@ -15,13 +15,13 @@ static const std::string INDEX_IVFFLAT = "ivfflat";
std::string ConvertUtil::IndexType2Str(megasearch::IndexType index) {
static const std::map<megasearch::IndexType, std::string> s_index2str = {
{megasearch::IndexType::raw, INDEX_RAW},
{megasearch::IndexType::ivfflat, INDEX_IVFFLAT}
{megasearch::IndexType::cpu_idmap, INDEX_RAW},
{megasearch::IndexType::gpu_ivfflat, INDEX_IVFFLAT}
};
const auto& iter = s_index2str.find(index);
if(iter == s_index2str.end()) {
throw Exception(StatusCode::Invalid, "Invalid index type");
throw Exception(StatusCode::InvalidAgument, "Invalid index type");
}
return iter->second;
......@@ -29,13 +29,13 @@ std::string ConvertUtil::IndexType2Str(megasearch::IndexType index) {
megasearch::IndexType ConvertUtil::Str2IndexType(const std::string& type) {
static const std::map<std::string, megasearch::IndexType> s_str2index = {
{INDEX_RAW, megasearch::IndexType::raw},
{INDEX_IVFFLAT, megasearch::IndexType::ivfflat}
{INDEX_RAW, megasearch::IndexType::cpu_idmap},
{INDEX_IVFFLAT, megasearch::IndexType::gpu_ivfflat}
};
const auto& iter = s_str2index.find(type);
if(iter == s_str2index.end()) {
throw Exception(StatusCode::Invalid, "Invalid index type");
throw Exception(StatusCode::InvalidAgument, "Invalid index type");
}
return iter->second;
......
......@@ -30,18 +30,6 @@ MegasearchServiceHandler::DeleteTable(const std::string &table_name) {
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam &param) {
BaseTaskPtr task_ptr = CreateTablePartitionTask::Create(param);
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam &param) {
BaseTaskPtr task_ptr = DeleteTablePartitionTask::Create(param);
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::AddVector(std::vector<int64_t> &_return,
const std::string &table_name,
......@@ -51,11 +39,12 @@ MegasearchServiceHandler::AddVector(std::vector<int64_t> &_return,
}
void
MegasearchServiceHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return,
const std::string &table_name,
const std::vector<thrift::QueryRecord> &query_record_array,
const int64_t topk) {
BaseTaskPtr task_ptr = SearchVectorTask::Create(table_name, query_record_array, topk, _return);
MegasearchServiceHandler::SearchVector(std::vector<megasearch::thrift::TopKQueryResult> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t topk) {
BaseTaskPtr task_ptr = SearchVectorTask::Create(table_name, query_record_array, query_range_array, topk, _return);
MegasearchScheduler::ExecTask(task_ptr);
}
......@@ -65,6 +54,18 @@ MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std:
MegasearchScheduler::ExecTask(task_ptr);
}
int64_t
MegasearchServiceHandler::GetTableRowCount(const std::string& table_name) {
int64_t row_count = 0;
{
BaseTaskPtr task_ptr = GetTableRowCountTask::Create(table_name, row_count);
MegasearchScheduler::ExecTask(task_ptr);
task_ptr->WaitToFinish();
}
return row_count;
}
void
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
BaseTaskPtr task_ptr = ShowTablesTask::Create(_return);
......
......@@ -19,15 +19,15 @@ public:
MegasearchServiceHandler();
/**
* @brief Create table method
*
* This method is used to create table
*
* @param param, use to provide table information to be created.
*
*
* @param param
*/
* @brief Create table method
*
* This method is used to create table
*
* @param param, use to provide table information to be created.
*
*
* @param param
*/
void CreateTable(const megasearch::thrift::TableSchema& param);
/**
......@@ -42,30 +42,6 @@ public:
*/
void DeleteTable(const std::string& table_name);
/**
* @brief Create table partition
*
* This method is used to create table partition.
*
* @param param, use to provide partition information to be created.
*
*
* @param param
*/
void CreateTablePartition(const megasearch::thrift::CreateTablePartitionParam& param);
/**
* @brief Delete table partition
*
* This method is used to delete table partition.
*
* @param param, use to provide partition information to be deleted.
*
*
* @param param
*/
void DeleteTablePartition(const megasearch::thrift::DeleteTablePartitionParam& param);
/**
* @brief Add vector array to table
*
......@@ -90,25 +66,28 @@ public:
*
* @param table_name, table_name is queried.
* @param query_record_array, all vector are going to be queried.
* @param query_range_array, optional ranges for conditional search. If not specified, search whole table
* @param topk, how many similarity vectors will be searched.
*
* @return query result array.
*
* @param table_name
* @param query_record_array
* @param query_range_array
* @param topk
*/
void SearchVector(std::vector<megasearch::thrift::TopKQueryResult> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::QueryRecord> & query_record_array,
const std::vector<megasearch::thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t topk);
/**
* @brief Show table information
* @brief Get table schema
*
* This method is used to show table information.
* This method is used to get table schema.
*
* @param table_name, which table is show.
* @param table_name, target table name.
*
* @return table schema
*
......@@ -116,6 +95,19 @@ public:
*/
void DescribeTable(megasearch::thrift::TableSchema& _return, const std::string& table_name);
/**
* @brief Get table row count
*
* This method is used to get table row count.
*
* @param table_name, target table name.
*
* @return table row count
*
* @param table_name
*/
int64_t GetTableRowCount(const std::string& table_name);
/**
* @brief List all tables in database
*
......
......@@ -24,7 +24,7 @@ namespace {
{SERVER_FILE_NOT_FOUND, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_NOT_IMPLEMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_BLOCKING_QUEUE_EMPTY, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_GROUP_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
{SERVER_TABLE_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
{SERVER_INVALID_TIME_RANGE, thrift::ErrorCode::ILLEGAL_RANGE},
{SERVER_INVALID_VECTOR_DIMENSION, thrift::ErrorCode::ILLEGAL_DIMENSION},
};
......@@ -40,7 +40,7 @@ namespace {
{SERVER_FILE_NOT_FOUND, "file not found"},
{SERVER_NOT_IMPLEMENT, "not implemented"},
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
{SERVER_GROUP_NOT_EXIST, "group not exist"},
{SERVER_TABLE_NOT_EXIST, "table not exist"},
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
};
......
......@@ -5,15 +5,13 @@
******************************************************************************/
#include "MegasearchTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"
#include "version.h"
namespace zilliz {
namespace vecwise {
......@@ -64,9 +62,18 @@ namespace {
return db_wrapper.DB();
}
ThreadPool& GetThreadPool() {
static ThreadPool pool(6);
return pool;
engine::EngineType EngineType(int type) {
static std::map<int, engine::EngineType> map_type = {
{0, engine::EngineType::INVALID},
{1, engine::EngineType::FAISS_IDMAP},
{2, engine::EngineType::FAISS_IVFFLAT},
};
if(map_type.find(type) == map_type.end()) {
return engine::EngineType::INVALID;
}
return map_type[type];
}
}
......@@ -85,16 +92,20 @@ ServerError CreateTableTask::OnExecute() {
TimeRecorder rc("CreateTableTask");
try {
if(schema_.vector_column_array.empty()) {
if(schema_.table_name.empty() || schema_.dimension == 0 || schema_.index_type == 0) {
return SERVER_INVALID_ARGUMENT;
}
IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);
//step 1: construct table schema
engine::meta::TableSchema table_info;
table_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
table_info.table_id = schema_.table_name;
table_info.dimension_ = (uint16_t)schema_.dimension;
table_info.table_id_ = schema_.table_name;
table_info.engine_type_ = (int)EngineType(schema_.index_type);
table_info.store_raw_data_ = schema_.store_raw_vector;
//step 2: create table
engine::Status stat = DB()->CreateTable(table_info);
if(!stat.ok()) {//could exist
if(!stat.ok()) {//table could exist
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
return SERVER_SUCCESS;
......@@ -129,10 +140,10 @@ ServerError DescribeTableTask::OnExecute() {
try {
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
table_info.table_id_ = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_code_ = SERVER_TABLE_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
return error_code_;
......@@ -168,46 +179,6 @@ ServerError DeleteTableTask::OnExecute() {
error_msg_ = "delete table not implemented";
SERVER_LOG_ERROR << error_msg_;
IVecIdMapper::GetInstance()->DeleteGroup(table_name_);
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateTablePartitionTask::CreateTablePartitionTask(const thrift::CreateTablePartitionParam &param)
: BaseTask(DDL_DML_TASK_GROUP),
param_(param) {
}
BaseTaskPtr CreateTablePartitionTask::Create(const thrift::CreateTablePartitionParam &param) {
return std::shared_ptr<BaseTask>(new CreateTablePartitionTask(param));
}
ServerError CreateTablePartitionTask::OnExecute() {
error_code_ = SERVER_NOT_IMPLEMENT;
error_msg_ = "create table partition not implemented";
SERVER_LOG_ERROR << error_msg_;
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteTablePartitionTask::DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam &param)
: BaseTask(DDL_DML_TASK_GROUP),
param_(param) {
}
BaseTaskPtr DeleteTablePartitionTask::Create(const thrift::DeleteTablePartitionParam &param) {
return std::shared_ptr<BaseTask>(new DeleteTablePartitionTask(param));
}
ServerError DeleteTablePartitionTask::OnExecute() {
error_code_ = SERVER_NOT_IMPLEMENT;
error_msg_ = "delete table partition not implemented";
SERVER_LOG_ERROR << error_msg_;
return SERVER_NOT_IMPLEMENT;
}
......@@ -223,7 +194,6 @@ BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
}
ServerError ShowTablesTask::OnExecute() {
IVecIdMapper::GetInstance()->AllGroups(tables_);
return SERVER_SUCCESS;
}
......@@ -253,31 +223,33 @@ ServerError AddVectorTask::OnExecute() {
return SERVER_SUCCESS;
}
//step 1: check table existence
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
table_info.table_id_ = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_code_ = SERVER_TABLE_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
rc.Record("get group info");
rc.Record("check validation");
//step 2: prepare float data
uint64_t vec_count = (uint64_t)record_array_.size();
uint64_t group_dim = table_info.dimension;
uint64_t group_dim = table_info.dimension_;
std::vector<float> vec_f;
vec_f.resize(vec_count*group_dim);//allocate enough memory
for(uint64_t i = 0; i < vec_count; i++) {
const auto& record = record_array_[i];
if(record.vector_map.empty()) {
if(record.vector_data.empty()) {
error_code_ = SERVER_INVALID_ARGUMENT;
error_msg_ = "No vector provided in record";
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
uint64_t vec_dim = record.vector_map.begin()->second.size()/sizeof(double);//how many double value?
uint64_t vec_dim = record.vector_data.size()/sizeof(double);//how many double value?
if(vec_dim != group_dim) {
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
<< " vs. group dimension:" << group_dim;
......@@ -287,7 +259,7 @@ ServerError AddVectorTask::OnExecute() {
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
const double* d_p = reinterpret_cast<const double*>(record.vector_data.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
}
......@@ -295,6 +267,7 @@ ServerError AddVectorTask::OnExecute() {
rc.Record("prepare vectors data");
//step 3: insert vectors
stat = DB()->InsertVectors(table_name_, vec_count, vec_f.data(), record_ids_);
rc.Record("add vectors to engine");
if(!stat.ok()) {
......@@ -309,22 +282,8 @@ ServerError AddVectorTask::OnExecute() {
return SERVER_UNEXPECTED_ERROR;
}
//persist attributes
for(uint64_t i = 0; i < vec_count; i++) {
const auto &record = record_array_[i];
//any attributes?
if(record.attribute_map.empty()) {
continue;
}
std::string nid = std::to_string(record_ids_[i]);
std::string attrib_str;
AttributeSerializer::Encode(record.attribute_map, attrib_str);
IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_);
}
rc.Record("persist vector attributes");
rc.Record("do insert");
rc.Elapse("totally cost");
} catch (std::exception& ex) {
error_code_ = SERVER_UNEXPECTED_ERROR;
......@@ -338,28 +297,33 @@ ServerError AddVectorTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask::SearchVectorTask(const std::string& table_name,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t top_k,
const std::vector<thrift::QueryRecord>& record_array,
std::vector<thrift::TopKQueryResult>& result_array)
: BaseTask(DQL_TASK_GROUP),
table_name_(table_name),
record_array_(query_record_array),
range_array_(query_range_array),
top_k_(top_k),
record_array_(record_array),
result_array_(result_array) {
}
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
const std::vector<thrift::QueryRecord>& record_array,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array) {
return std::shared_ptr<BaseTask>(new SearchVectorTask(table_name, top_k, record_array, result_array));
return std::shared_ptr<BaseTask>(new SearchVectorTask(table_name,
query_record_array, query_range_array, top_k, result_array));
}
ServerError SearchVectorTask::OnExecute() {
try {
TimeRecorder rc("SearchVectorTask");
//step 1: check validation
if(top_k_ <= 0 || record_array_.empty()) {
error_code_ = SERVER_INVALID_ARGUMENT;
error_msg_ = "Invalid topk value, or query record array is empty";
......@@ -367,40 +331,44 @@ ServerError SearchVectorTask::OnExecute() {
return error_code_;
}
//step 2: check table existence
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
table_info.table_id_ = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_code_ = SERVER_TABLE_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
rc.Record("check validation");
//step 3: prepare float data
std::vector<float> vec_f;
uint64_t record_count = (uint64_t)record_array_.size();
vec_f.resize(record_count*table_info.dimension);
vec_f.resize(record_count*table_info.dimension_);
for(uint64_t i = 0; i < record_array_.size(); i++) {
const auto& record = record_array_[i];
if (record.vector_map.empty()) {
if (record.vector_data.empty()) {
error_code_ = SERVER_INVALID_ARGUMENT;
error_msg_ = "Query record has no vector";
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
if (vec_dim != table_info.dimension) {
uint64_t vec_dim = record.vector_data.size() / sizeof(double);//how many double value?
if (vec_dim != table_info.dimension_) {
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
<< " vs. group dimension:" << table_info.dimension;
<< " vs. group dimension:" << table_info.dimension_;
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
error_msg_ = "Engine failed: " + stat.ToString();
return error_code_;
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
const double* d_p = reinterpret_cast<const double*>(record.vector_data.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
}
......@@ -408,6 +376,8 @@ ServerError SearchVectorTask::OnExecute() {
rc.Record("prepare vector data");
//step 4: search vectors
std::vector<DB_DATE> dates;
engine::QueryResults results;
stat = DB()->Query(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
......@@ -422,32 +392,18 @@ ServerError SearchVectorTask::OnExecute() {
return SERVER_UNEXPECTED_ERROR;
}
//construct result array
rc.Record("do search");
//step 5: construct result array
for(uint64_t i = 0; i < record_count; i++) {
auto& result = results[i];
const auto& record = record_array_[i];
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
for(auto& pair : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
//need get attributes?
if(record.selected_column_array.empty()) {
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
continue;
}
std::string nid = std::to_string(id);
std::string attrib_str;
IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_);
AttribMap attrib_map;
AttributeSerializer::Decode(attrib_str, attrib_map);
for(auto& attribute : record.selected_column_array) {
thrift_result.column_map[attribute] = attrib_map[attribute];
}
thrift_result.__set_id(pair.first);
thrift_result.__set_score(pair.second);
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
......@@ -466,6 +422,32 @@ ServerError SearchVectorTask::OnExecute() {
return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
GetTableRowCountTask::GetTableRowCountTask(const std::string& table_name, int64_t& row_count)
: BaseTask(DQL_TASK_GROUP),
table_name_(table_name),
row_count_(row_count) {
}
BaseTaskPtr GetTableRowCountTask::Create(const std::string& table_name, int64_t& row_count) {
return std::shared_ptr<BaseTask>(new GetTableRowCountTask(table_name, row_count));
}
ServerError GetTableRowCountTask::OnExecute() {
if(table_name_.empty()) {
error_code_ = SERVER_UNEXPECTED_ERROR;
error_msg_ = "Table name cannot be empty";
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
error_code_ = SERVER_NOT_IMPLEMENT;
error_msg_ = "Not implemented";
SERVER_LOG_ERROR << error_msg_;
return error_code_;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
PingTask::PingTask(const std::string& cmd, std::string& result)
: BaseTask(PING_TASK_GROUP),
......@@ -480,7 +462,7 @@ BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) {
ServerError PingTask::OnExecute() {
if(cmd_ == "version") {
result_ = "v1.2.0";//currently hardcode
result_ = MEGASEARCH_VERSION;
}
return SERVER_SUCCESS;
......
......@@ -65,36 +65,6 @@ private:
std::string table_name_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CreateTablePartitionTask : public BaseTask {
public:
static BaseTaskPtr Create(const thrift::CreateTablePartitionParam &param);
protected:
CreateTablePartitionTask(const thrift::CreateTablePartitionParam &param);
ServerError OnExecute() override;
private:
const thrift::CreateTablePartitionParam &param_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class DeleteTablePartitionTask : public BaseTask {
public:
static BaseTaskPtr Create(const thrift::DeleteTablePartitionParam &param);
protected:
DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam &param);
ServerError OnExecute() override;
private:
const thrift::DeleteTablePartitionParam &param_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class ShowTablesTask : public BaseTask {
public:
......@@ -133,14 +103,16 @@ private:
class SearchVectorTask : public BaseTask {
public:
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<thrift::QueryRecord>& record_array,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array);
protected:
SearchVectorTask(const std::string& table_name,
const std::vector<thrift::RowRecord> & query_record_array,
const std::vector<megasearch::thrift::Range> & query_range_array,
const int64_t top_k,
const std::vector<thrift::QueryRecord>& record_array,
std::vector<thrift::TopKQueryResult>& result_array);
ServerError OnExecute() override;
......@@ -148,10 +120,26 @@ protected:
private:
std::string table_name_;
int64_t top_k_;
const std::vector<thrift::QueryRecord>& record_array_;
const std::vector<thrift::RowRecord>& record_array_;
const std::vector<megasearch::thrift::Range>& range_array_;
std::vector<thrift::TopKQueryResult>& result_array_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class GetTableRowCountTask : public BaseTask {
public:
static BaseTaskPtr Create(const std::string& table_name, int64_t& row_count);
protected:
GetTableRowCountTask(const std::string& table_name, int64_t& row_count);
ServerError OnExecute() override;
private:
std::string table_name_;
int64_t& row_count_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class PingTask : public BaseTask {
public:
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "RocksIdMapper.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "rocksdb/db.h"
#include "rocksdb/slice.h"
#include "rocksdb/options.h"
#include <exception>
namespace zilliz {
namespace vecwise {
namespace server {
static const std::string ROCKSDB_DEFAULT_GROUP = "default";
RocksIdMapper::RocksIdMapper()
: db_(nullptr) {
OpenDb();
}
RocksIdMapper::~RocksIdMapper() {
CloseDb();
}
void RocksIdMapper::OpenDb() {
std::lock_guard<std::mutex> lck(db_mutex_);
if(db_) {
return;
}
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
std::string db_path = config.GetValue(CONFIG_DB_PATH);
db_path += "/id_mapping";
CommonUtil::CreateDirectory(db_path);
rocksdb::Options options;
// Optimize RocksDB. This is the easiest way to get RocksDB to perform well
options.IncreaseParallelism();
options.OptimizeLevelStyleCompaction();
// create the DB if it's not already present
options.create_if_missing = true;
options.max_open_files = config.GetInt32Value(CONFIG_DB_IDMAPPER_MAX_FILE, 512);
//load column families
std::vector<std::string> column_names;
rocksdb::Status s = rocksdb::DB::ListColumnFamilies(options, db_path, &column_names);
if (!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to initialize:" << s.ToString();
}
if(column_names.empty()) {
column_names.push_back("default");
}
SERVER_LOG_INFO << "ID mapper has " << std::to_string(column_names.size()) << " groups";
std::vector<rocksdb::ColumnFamilyDescriptor> column_families;
for(auto& column_name : column_names) {
rocksdb::ColumnFamilyDescriptor desc;
desc.name = column_name;
column_families.emplace_back(desc);
}
// open DB
std::vector<rocksdb::ColumnFamilyHandle*> column_handles;
s = rocksdb::DB::Open(options, db_path, column_families, &column_handles, &db_);
if(!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to initialize:" << s.ToString();
db_ = nullptr;
}
column_handles_.clear();
for(auto handler : column_handles) {
column_handles_.insert(std::make_pair(handler->GetName(), handler));
}
}
void RocksIdMapper::CloseDb() {
std::lock_guard<std::mutex> lck(db_mutex_);
for(auto& iter : column_handles_) {
delete iter.second;
}
column_handles_.clear();
if(db_) {
db_->Close();
delete db_;
}
}
ServerError RocksIdMapper::AddGroup(const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);
return AddGroupInternal(group);
}
bool RocksIdMapper::IsGroupExist(const std::string& group) const {
std::lock_guard<std::mutex> lck(db_mutex_);
return IsGroupExistInternal(group);
}
ServerError RocksIdMapper::AllGroups(std::vector<std::string>& groups) const {
groups.clear();
std::lock_guard<std::mutex> lck(db_mutex_);
for(auto& pair : column_handles_) {
if(pair.first == ROCKSDB_DEFAULT_GROUP) {
continue;
}
groups.push_back(pair.first);
}
return SERVER_SUCCESS;
}
ServerError RocksIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);
return PutInternal(nid, sid, group);
}
ServerError RocksIdMapper::Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group) {
if(nid.size() != sid.size()) {
return SERVER_INVALID_ARGUMENT;
}
std::lock_guard<std::mutex> lck(db_mutex_);
ServerError err = SERVER_SUCCESS;
for(size_t i = 0; i < nid.size(); i++) {
err = PutInternal(nid[i], sid[i], group);
if(err != SERVER_SUCCESS) {
return err;
}
}
return err;
}
ServerError RocksIdMapper::Get(const std::string& nid, std::string& sid, const std::string& group) const {
std::lock_guard<std::mutex> lck(db_mutex_);
return GetInternal(nid, sid, group);
}
ServerError RocksIdMapper::Get(const std::vector<std::string>& nid, std::vector<std::string>& sid, const std::string& group) const {
sid.clear();
std::lock_guard<std::mutex> lck(db_mutex_);
ServerError err = SERVER_SUCCESS;
for(size_t i = 0; i < nid.size(); i++) {
std::string str_id;
ServerError temp_err = GetInternal(nid[i], str_id, group);
if(temp_err != SERVER_SUCCESS) {
sid.push_back("");
SERVER_LOG_ERROR << "ID mapper failed to get id: " << nid[i];
err = temp_err;
continue;
}
sid.push_back(str_id);
}
return err;
}
ServerError RocksIdMapper::Delete(const std::string& nid, const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);
return DeleteInternal(nid, group);
}
ServerError RocksIdMapper::DeleteGroup(const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);
return DeleteGroupInternal(group);
}
//internal methods(whitout lock)
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
ServerError RocksIdMapper::AddGroupInternal(const std::string& group) {
if(!IsGroupExistInternal(group)) {
if(db_ == nullptr) {
return SERVER_NULL_POINTER;
}
try {//add group
rocksdb::ColumnFamilyHandle *cfh = nullptr;
rocksdb::Status s = db_->CreateColumnFamily(rocksdb::ColumnFamilyOptions(), group, &cfh);
if (!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to create group:" << s.ToString();
return SERVER_UNEXPECTED_ERROR;
} else {
column_handles_.insert(std::make_pair(group, cfh));
}
} catch(std::exception& ex) {
SERVER_LOG_ERROR << "ID mapper failed to create group: " << ex.what();
return SERVER_UNEXPECTED_ERROR;
}
}
return SERVER_SUCCESS;
}
bool RocksIdMapper::IsGroupExistInternal(const std::string& group) const {
std::string group_name = group;
if(group_name.empty()){
group_name = ROCKSDB_DEFAULT_GROUP;
}
return (column_handles_.count(group_name) > 0 && column_handles_[group_name] != nullptr);
}
ServerError RocksIdMapper::PutInternal(const std::string& nid, const std::string& sid, const std::string& group) {
if(db_ == nullptr) {
return SERVER_NULL_POINTER;
}
rocksdb::Slice key(nid);
rocksdb::Slice value(sid);
if(group.empty()) {//to default group
rocksdb::Status s = db_->Put(rocksdb::WriteOptions(), key, value);
if (!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to put:" << s.ToString();
return SERVER_UNEXPECTED_ERROR;
}
} else {
//try create group
if(AddGroupInternal(group) != SERVER_SUCCESS){
return SERVER_UNEXPECTED_ERROR;
}
rocksdb::ColumnFamilyHandle *cfh = column_handles_[group];
rocksdb::Status s = db_->Put(rocksdb::WriteOptions(), cfh, key, value);
if (!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to put:" << s.ToString();
return SERVER_UNEXPECTED_ERROR;
}
}
return SERVER_SUCCESS;
}
ServerError RocksIdMapper::GetInternal(const std::string& nid, std::string& sid, const std::string& group) const {
sid = "";
if(db_ == nullptr) {
return SERVER_NULL_POINTER;
}
rocksdb::ColumnFamilyHandle *cfh = nullptr;
if(column_handles_.count(group) != 0) {
cfh = column_handles_.at(group);
}
rocksdb::Slice key(nid);
rocksdb::Status s;
if(cfh){
s = db_->Get(rocksdb::ReadOptions(), cfh, key, &sid);
} else {
s = db_->Get(rocksdb::ReadOptions(), key, &sid);
}
if(!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to get:" << s.ToString();
return SERVER_UNEXPECTED_ERROR;
}
return SERVER_SUCCESS;
}
ServerError RocksIdMapper::DeleteInternal(const std::string& nid, const std::string& group) {
if(db_ == nullptr) {
return SERVER_NULL_POINTER;
}
rocksdb::ColumnFamilyHandle *cfh = nullptr;
if(column_handles_.count(group) != 0) {
cfh = column_handles_.at(group);
}
rocksdb::Slice key(nid);
rocksdb::Status s;
if(cfh){
s = db_->Delete(rocksdb::WriteOptions(), cfh, key);
} else {
s = db_->Delete(rocksdb::WriteOptions(), key);
}
if(!s.ok()) {
SERVER_LOG_ERROR << "ID mapper failed to delete:" << s.ToString();
return SERVER_UNEXPECTED_ERROR;
}
return SERVER_SUCCESS;
}
ServerError RocksIdMapper::DeleteGroupInternal(const std::string& group) {
if(db_ == nullptr) {
return SERVER_NULL_POINTER;
}
rocksdb::ColumnFamilyHandle *cfh = nullptr;
if(column_handles_.count(group) != 0) {
cfh = column_handles_.at(group);
}
if(cfh) {
db_->DropColumnFamily(cfh);
db_->DestroyColumnFamilyHandle(cfh);
column_handles_.erase(group);
}
return SERVER_SUCCESS;
}
}
}
}
\ No newline at end of file
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include "VecIdMapper.h"
#include <string>
#include <vector>
#include <unordered_map>
#include <mutex>
namespace rocksdb {
class DB;
class ColumnFamilyHandle;
}
namespace zilliz {
namespace vecwise {
namespace server {
class RocksIdMapper : public IVecIdMapper{
public:
RocksIdMapper();
~RocksIdMapper();
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
ServerError Get(const std::string& nid, std::string& sid, const std::string& group = "") const override;
ServerError Get(const std::vector<std::string>& nid, std::vector<std::string>& sid, const std::string& group = "") const override;
ServerError Delete(const std::string& nid, const std::string& group = "") override;
ServerError DeleteGroup(const std::string& group) override;
private:
void OpenDb();
void CloseDb();
ServerError AddGroupInternal(const std::string& group);
bool IsGroupExistInternal(const std::string& group) const;
ServerError PutInternal(const std::string& nid, const std::string& sid, const std::string& group);
ServerError GetInternal(const std::string& nid, std::string& sid, const std::string& group) const;
ServerError DeleteInternal(const std::string& nid, const std::string& group);
ServerError DeleteGroupInternal(const std::string& group);
private:
rocksdb::DB* db_;
mutable std::unordered_map<std::string, rocksdb::ColumnFamilyHandle*> column_handles_;
mutable std::mutex db_mutex_;
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecIdMapper.h"
#include "RocksIdMapper.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "rocksdb/db.h"
#include "rocksdb/slice.h"
#include "rocksdb/options.h"
#include <exception>
#include <unordered_map>
namespace zilliz {
namespace vecwise {
namespace server {
IVecIdMapper* IVecIdMapper::GetInstance() {
#if 0
static SimpleIdMapper s_mapper;
return &s_mapper;
#else
static RocksIdMapper s_mapper;
return &s_mapper;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SimpleIdMapper : public IVecIdMapper{
public:
SimpleIdMapper();
~SimpleIdMapper();
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
ServerError Get(const std::string& nid, std::string& sid, const std::string& group = "") const override;
ServerError Get(const std::vector<std::string>& nid, std::vector<std::string>& sid, const std::string& group = "") const override;
ServerError Delete(const std::string& nid, const std::string& group = "") override;
ServerError DeleteGroup(const std::string& group) override;
private:
using ID_MAPPING = std::unordered_map<std::string, std::string>;
mutable std::unordered_map<std::string, ID_MAPPING> id_groups_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SimpleIdMapper::SimpleIdMapper() {
}
SimpleIdMapper::~SimpleIdMapper() {
}
ServerError
SimpleIdMapper::AddGroup(const std::string& group) {
if(id_groups_.count(group) == 0) {
id_groups_.insert(std::make_pair(group, ID_MAPPING()));
}
}
//not thread-safe
bool
SimpleIdMapper::IsGroupExist(const std::string& group) const {
return id_groups_.count(group) > 0;
}
ServerError SimpleIdMapper::AllGroups(std::vector<std::string>& groups) const {
groups.clear();
for(auto& pair : id_groups_) {
groups.push_back(pair.first);
}
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
ID_MAPPING& mapping = id_groups_[group];
mapping[nid] = sid;
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group) {
if(nid.size() != sid.size()) {
return SERVER_INVALID_ARGUMENT;
}
ID_MAPPING& mapping = id_groups_[group];
for(size_t i = 0; i < nid.size(); i++) {
mapping[nid[i]] = sid[i];
}
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Get(const std::string& nid, std::string& sid, const std::string& group) const {
ID_MAPPING& mapping = id_groups_[group];
auto iter = mapping.find(nid);
if(iter == mapping.end()) {
return SERVER_INVALID_ARGUMENT;
}
sid = iter->second;
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Get(const std::vector<std::string>& nid, std::vector<std::string>& sid, const std::string& group) const {
sid.clear();
ID_MAPPING& mapping = id_groups_[group];
ServerError err = SERVER_SUCCESS;
for(size_t i = 0; i < nid.size(); i++) {
auto iter = mapping.find(nid[i]);
if(iter == mapping.end()) {
sid.push_back("");
SERVER_LOG_ERROR << "ID mapper failed to find id: " << nid[i];
err = SERVER_INVALID_ARGUMENT;
continue;
}
sid.push_back(iter->second);
}
return err;
}
//not thread-safe
ServerError SimpleIdMapper::Delete(const std::string& nid, const std::string& group) {
ID_MAPPING& mapping = id_groups_[group];
mapping.erase(nid);
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::DeleteGroup(const std::string& group) {
id_groups_.erase(group);
return SERVER_SUCCESS;
}
}
}
}
\ No newline at end of file
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include <string>
#include <vector>
namespace zilliz {
namespace vecwise {
namespace server {
class IVecIdMapper {
public:
static IVecIdMapper* GetInstance();
virtual ~IVecIdMapper(){}
virtual ServerError AddGroup(const std::string& group) = 0;
virtual bool IsGroupExist(const std::string& group) const = 0;
virtual ServerError AllGroups(std::vector<std::string>& groups) const = 0;
virtual ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") = 0;
virtual ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") = 0;
virtual ServerError Get(const std::string& nid, std::string& sid, const std::string& group = "") const = 0;
//NOTE: the 'sid' will be cleared at begin of the function
virtual ServerError Get(const std::vector<std::string>& nid, std::vector<std::string>& sid, const std::string& group = "") const = 0;
virtual ServerError Delete(const std::string& nid, const std::string& group = "") = 0;
virtual ServerError DeleteGroup(const std::string& group) = 0;
};
}
}
}
此差异已折叠。
......@@ -31,7 +31,7 @@ constexpr ServerError SERVER_INVALID_ARGUMENT = ToGlobalServerErrorCode(0x004);
constexpr ServerError SERVER_FILE_NOT_FOUND = ToGlobalServerErrorCode(0x005);
constexpr ServerError SERVER_NOT_IMPLEMENT = ToGlobalServerErrorCode(0x006);
constexpr ServerError SERVER_BLOCKING_QUEUE_EMPTY = ToGlobalServerErrorCode(0x007);
constexpr ServerError SERVER_GROUP_NOT_EXIST = ToGlobalServerErrorCode(0x008);
constexpr ServerError SERVER_TABLE_NOT_EXIST = ToGlobalServerErrorCode(0x008);
constexpr ServerError SERVER_INVALID_TIME_RANGE = ToGlobalServerErrorCode(0x009);
constexpr ServerError SERVER_INVALID_VECTOR_DIMENSION = ToGlobalServerErrorCode(0x00a);
constexpr ServerError SERVER_LICENSE_VALIDATION_FAIL = ToGlobalServerErrorCode(0x00b);
......
此差异已折叠。
此差异已折叠。
......@@ -38,7 +38,7 @@ engine::Options DBTest::GetOptions() {
void DBTest::SetUp() {
InitLog();
auto options = GetOptions();
db_ = engine::DBFactory::Build(options, "Faiss,IDMap");
db_ = engine::DBFactory::Build(options);
}
void DBTest::TearDown() {
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册