提交 8d2b16a6 编写于 作者: Z zhiru

Merge remote-tracking branch 'upstream/branch-0.4.0' into cmake


Former-commit-id: 4a6d01614e1ec35b25523be67a5d7f7676f75b17
......@@ -17,11 +17,16 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-430 - Search no result if index created with FLAT
- MS-443 - Create index hang again
- MS-436 - Delete vectors failed if index created with index_type: IVF_FLAT/IVF_SQ8
- MS-450 - server hang after run stop_server.sh
- MS-449 - Add vectors twice success, once with ids, the other no ids
- MS-450 - server hang after run stop_server.sh
- MS-458 - Keep building index for one file when no gpu resource
- MS-461 - Mysql meta unittest failed
- MS-462 - Run milvus server twice, should display error
- MS-463 - Search timeout
- MS-467 - mysql db test failed
- MS-470 - Drop index succeeds on a table that was not created
- MS-471 - code coverage run failed
- MS-492 - Drop index failed if index have been created with index_type: FLAT
## Improvement
- MS-327 - Clean code for milvus
......@@ -80,6 +85,9 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-455 - Distribute tasks by minimal cost in scheduler
- MS-460 - Put transport speed as weight when choosing neighbour to execute task
- MS-459 - Add cache for pick function in tasktable
- MS-482 - Change search stream transport to unary in grpc
- MS-487 - Define metric type in CreateTable
- MS-488 - Improve code format in scheduler
## New Feature
- MS-343 - Implement ResourceMgr
......
......@@ -74,6 +74,9 @@ resource_config:
device_id: 0
enable_loader: true
enable_executor: true
gpu_resource_num: 2
pinned_memory: 300
temp_memory: 300
# gtx1660:
# type: GPU
......
......@@ -12,7 +12,7 @@ FILE_INFO_OUTPUT_NEW="output_new.info"
DIR_LCOV_OUTPUT="lcov_out"
DIR_GCNO="cmake_build"
DIR_UNITTEST="milvus/bin"
DIR_UNITTEST="milvus/unittest"
MYSQL_USER_NAME=root
MYSQL_PASSWORD=Fantast1c
......@@ -77,6 +77,7 @@ for test in `ls ${DIR_UNITTEST}`; do
# run unittest
./${DIR_UNITTEST}/${test} "${args}"
if [ $? -ne 0 ]; then
echo ${args}
echo ${DIR_UNITTEST}/${test} "run failed"
fi
done
......@@ -93,6 +94,7 @@ ${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \
"/usr/*" \
"*/boost/*" \
"*/cmake_build/*_ep-prefix/*" \
"src/core/cmake_build*" \
# gen html report
${LCOV_GEN_CMD} "${FILE_INFO_OUTPUT_NEW}" --output-directory ${DIR_LCOV_OUTPUT}/
\ No newline at end of file
......@@ -261,7 +261,8 @@ else()
message(STATUS "FAISS URL = ${FAISS_SOURCE_URL}")
endif()
# set(FAISS_MD5 "a589663865a8558205533c8ac414278c")
set(FAISS_MD5 "31167ecbd1903fec600dc4ac00b9be9e")
# set(FAISS_MD5 "57da9c4f599cc8fa4260488b1c96e1cc") # commit-id 6dbdf75987c34a2c853bd172ea0d384feea8358c
set(FAISS_MD5 "21deb1c708490ca40ecb899122c01403") # commit-id 643e48f479637fd947e7b93fa4ca72b38ecc9a39
if(DEFINED ENV{KNOWHERE_ARROW_URL})
set(ARROW_SOURCE_URL "$ENV{KNOWHERE_ARROW_URL}")
......
......@@ -26,6 +26,9 @@ class KnowhereException : public std::exception {
};
#define KNOHWERE_ERROR_MSG(MSG)\
printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
#define KNOWHERE_THROW_MSG(MSG)\
do {\
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);\
......
......@@ -8,6 +8,18 @@
namespace zilliz {
namespace knowhere {
// Couples one faiss GPU resource object with a sequential id so that a
// particular resource can be recognised again by comparing ids.
struct Resource {
    // Keeps a shared reference to r and stamps this wrapper with the next id.
    Resource(std::shared_ptr<faiss::gpu::StandardGpuResources> &r)
        : faiss_res(r), id(NextId()) {
    }

    // The wrapped faiss GPU resources.
    std::shared_ptr<faiss::gpu::StandardGpuResources> faiss_res;
    // Identifier assigned at construction; the first wrapper created gets 0.
    int64_t id;

 private:
    // Returns 0, 1, 2, ... on successive calls.
    // NOTE(review): the counter is a plain static with no synchronisation;
    // confirm that wrappers are only ever created under the manager's lock.
    static int64_t NextId() {
        static int64_t global_id = 0;
        return global_id++;
    }
};
using ResPtr = std::shared_ptr<Resource>;
using ResWPtr = std::weak_ptr<Resource>;
class FaissGpuResourceMgr {
public:
struct DeviceParams {
......@@ -17,14 +29,11 @@ class FaissGpuResourceMgr {
};
public:
using ResPtr = std::shared_ptr<faiss::gpu::StandardGpuResources>;
using ResWPtr = std::weak_ptr<faiss::gpu::StandardGpuResources>;
static FaissGpuResourceMgr &
GetInstance();
void
AllocateTempMem(ResPtr &res, const int64_t& device_id, const int64_t& size);
AllocateTempMem(ResPtr &resource, const int64_t& device_id, const int64_t& size);
void
InitDevice(int64_t device_id,
......@@ -32,12 +41,23 @@ class FaissGpuResourceMgr {
int64_t temp_mem_size = 0,
int64_t res_num = 2);
void InitResource();
void
InitResource();
// allocate gpu memory invoke by build or copy_to_gpu
ResPtr
GetRes(const int64_t &device_id, const int64_t& alloc_size = 0);
ResPtr GetRes(const int64_t &device_id, const int64_t& alloc_size = 0);
// allocate gpu memory before search
// Returns true if res is currently in the idle list of device_id (it is then removed from that list); returns false otherwise.
bool
GetRes(const int64_t& device_id, ResPtr &res, const int64_t& alloc_size = 0);
void MoveToInuse(const int64_t &device_id, const ResPtr& res);
void MoveToIdle(const int64_t &device_id, const ResPtr& res);
void
MoveToInuse(const int64_t &device_id, const ResPtr& res);
void
MoveToIdle(const int64_t &device_id, const ResPtr& res);
protected:
bool is_init = false;
......@@ -50,23 +70,24 @@ class FaissGpuResourceMgr {
class ResScope {
public:
ResScope(const int64_t device_id,std::shared_ptr<faiss::gpu::StandardGpuResources> &res) : resource(res), device_id(device_id) {
ResScope(const int64_t device_id, ResPtr &res) : resource(res), device_id(device_id) {
FaissGpuResourceMgr::GetInstance().MoveToInuse(device_id, resource);
}
~ResScope() {
resource->noTempMemory();
//resource->faiss_res->noTempMemory();
FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource);
}
private:
std::shared_ptr<faiss::gpu::StandardGpuResources> resource;
ResPtr resource;
int64_t device_id;
};
class GPUIndex {
public:
explicit GPUIndex(const int &device_id) : gpu_id_(device_id) {};
explicit GPUIndex(const int &device_id) : gpu_id_(device_id) {}
GPUIndex(const int& device_id, ResPtr resource): gpu_id_(device_id), res_(std::move(resource)){}
virtual VectorIndexPtr CopyGpuToCpu(const Config &config) = 0;
virtual VectorIndexPtr CopyGpuToGpu(const int64_t &device_id, const Config &config) = 0;
......@@ -76,13 +97,14 @@ class GPUIndex {
protected:
int64_t gpu_id_;
ResPtr res_ = nullptr;
};
class GPUIVF : public IVF, public GPUIndex {
public:
explicit GPUIVF(const int &device_id) : IVF(), GPUIndex(device_id) {}
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t &device_id)
: IVF(std::move(index)), GPUIndex(device_id) {};
explicit GPUIVF(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
: IVF(std::move(index)), GPUIndex(device_id, resource) {};
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
void set_index_model(IndexModelPtr model) override;
//DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
......@@ -107,7 +129,8 @@ class GPUIVF : public IVF, public GPUIndex {
class GPUIVFSQ : public GPUIVF {
public:
explicit GPUIVFSQ(const int &device_id) : GPUIVF(device_id) {}
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t& device_id) : GPUIVF(std::move(index),device_id) {};
explicit GPUIVFSQ(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr &resource)
: GPUIVF(std::move(index), device_id, resource) {};
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
public:
......
......@@ -39,8 +39,8 @@ using IDMAPPtr = std::shared_ptr<IDMAP>;
class GPUIDMAP : public IDMAP, public GPUIndex {
public:
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t &device_id)
: IDMAP(std::move(index)), GPUIndex(device_id) {}
explicit GPUIDMAP(std::shared_ptr<faiss::Index> index, const int64_t &device_id, ResPtr& res)
: IDMAP(std::move(index)), GPUIndex(device_id, res) {}
VectorIndexPtr CopyGpuToCpu(const Config &config) override;
float *GetRawVectors() override;
......
......@@ -31,26 +31,31 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) {
GETTENSOR(dataset)
auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
ResScope rs(gpu_device, res);
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device;
faiss::gpu::GpuIndexIVFFlat device_index(res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
return std::make_shared<IVFIndexModel>(host_index);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device);
if (temp_resource != nullptr) {
ResScope rs(gpu_device, temp_resource );
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = gpu_device;
faiss::gpu::GpuIndexIVFFlat device_index(temp_resource->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
return std::make_shared<IVFIndexModel>(host_index);
} else {
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
}
}
void GPUIVF::set_index_model(IndexModelPtr model) {
std::lock_guard<std::mutex> lk(mutex_);
auto host_index = std::static_pointer_cast<IVFIndexModel>(model);
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(gpu_id_, res);
auto device_index = faiss::gpu::index_cpu_to_gpu(res.get(), gpu_id_, host_index->index_.get());
if (auto gpures = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(gpu_id_, gpures);
res_ = gpures;
auto device_index = faiss::gpu::index_cpu_to_gpu(res_->faiss_res.get(), gpu_id_, host_index->index_.get());
index_.reset(device_index);
} else {
KNOWHERE_THROW_MSG("load index model error, can't get gpu_resource");
......@@ -94,9 +99,10 @@ void GPUIVF::LoadImpl(const BinarySet &index_binary) {
faiss::Index *index = faiss::read_index(&reader);
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(gpu_id_, res);
auto device_index = faiss::gpu::index_cpu_to_gpu(res.get(), gpu_id_, index);
if (auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) {
ResScope rs(gpu_id_, temp_res);
res_ = temp_res;
auto device_index = faiss::gpu::index_cpu_to_gpu(res_->faiss_res.get(), gpu_id_, index);
index_.reset(device_index);
} else {
KNOWHERE_THROW_MSG("Load error, can't get gpu resource");
......@@ -123,14 +129,21 @@ void GPUIVF::search_impl(int64_t n,
float *distances,
int64_t *labels,
const Config &cfg) {
if (auto device_index = std::static_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
// todo: allocate search memory
auto nprobe = cfg.get_with_default("nprobe", size_t(1));
// TODO(linxj): allocate mem
if (FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_, res_)) {
ResScope rs(gpu_id_, res_);
if (auto device_index = std::static_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
auto nprobe = cfg.get_with_default("nprobe", size_t(1));
std::lock_guard<std::mutex> lk(mutex_);
device_index->setNumProbes(nprobe);
device_index->search(n, (float *) data, k, distances, labels);
std::lock_guard<std::mutex> lk(mutex_);
device_index->setNumProbes(nprobe);
device_index->search(n, (float *) data, k, distances, labels);
}
} else {
KNOWHERE_THROW_MSG("search can't get gpu resource");
}
}
VectorIndexPtr GPUIVF::CopyGpuToCpu(const Config &config) {
......@@ -165,6 +178,7 @@ IndexModelPtr GPUIVFPQ::Train(const DatasetPtr &dataset, const Config &config) {
GETTENSOR(dataset)
// TODO(linxj): set device here.
// TODO(linxj): set gpu resource here.
faiss::gpu::StandardGpuResources res;
faiss::gpu::GpuIndexIVFPQ device_index(&res, dim, nlist, M, nbits, metric_type);
device_index.train(rows, (float *) p_data);
......@@ -202,17 +216,23 @@ IndexModelPtr GPUIVFSQ::Train(const DatasetPtr &dataset, const Config &config) {
index_type << "IVF" << nlist << "," << "SQ" << nbits;
auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type);
faiss::gpu::StandardGpuResources res;
auto device_index = faiss::gpu::index_cpu_to_gpu(&res, gpu_num, build_index);
device_index->train(rows, (float *) p_data);
auto temp_resource = FaissGpuResourceMgr::GetInstance().GetRes(gpu_num);
if (temp_resource != nullptr) {
ResScope rs(gpu_num, temp_resource );
auto device_index = faiss::gpu::index_cpu_to_gpu(temp_resource->faiss_res.get(), gpu_num, build_index);
device_index->train(rows, (float *) p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
delete device_index;
delete build_index;
delete device_index;
delete build_index;
return std::make_shared<IVFIndexModel>(host_index);
return std::make_shared<IVFIndexModel>(host_index);
}
else {
KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource");
}
}
VectorIndexPtr GPUIVFSQ::CopyGpuToCpu(const Config &config) {
......@@ -231,16 +251,16 @@ FaissGpuResourceMgr &FaissGpuResourceMgr::GetInstance() {
return instance;
}
void FaissGpuResourceMgr::AllocateTempMem(std::shared_ptr<faiss::gpu::StandardGpuResources> &res,
void FaissGpuResourceMgr::AllocateTempMem(ResPtr &resource,
const int64_t &device_id,
const int64_t &size) {
if (size) {
res->setTempMemory(size);
resource->faiss_res->setTempMemory(size);
}
else {
auto search = devices_params_.find(device_id);
if (search != devices_params_.end()) {
res->setTempMemory(search->second.temp_mem_size);
resource->faiss_res->setTempMemory(search->second.temp_mem_size);
}
// else do nothing. allocate when use.
}
......@@ -264,14 +284,19 @@ void FaissGpuResourceMgr::InitResource() {
for (int i = 0; i < device.second.resource_num; ++i) {
auto res = std::make_shared<faiss::gpu::StandardGpuResources>();
res->noTempMemory();
resource_vec.push_back(res);
// TODO(linxj): enable set pinned memory
//res->noTempMemory();
auto res_wrapper = std::make_shared<Resource>(res);
AllocateTempMem(res_wrapper, device.first, 0);
resource_vec.emplace_back(res_wrapper);
}
}
}
std::shared_ptr<faiss::gpu::StandardGpuResources> FaissGpuResourceMgr::GetRes(const int64_t &device_id,
const int64_t &alloc_size) {
ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id,
const int64_t &alloc_size) {
std::lock_guard<std::mutex> lk(mutex_);
if (!is_init) {
......@@ -282,21 +307,48 @@ std::shared_ptr<faiss::gpu::StandardGpuResources> FaissGpuResourceMgr::GetRes(co
auto search = idle_.find(device_id);
if (search != idle_.end()) {
auto res = search->second.back();
AllocateTempMem(res, device_id, alloc_size);
//AllocateTempMem(res, device_id, alloc_size);
search->second.pop_back();
return res;
}
return nullptr;
}
bool FaissGpuResourceMgr::GetRes(const int64_t &device_id,
ResPtr &res,
const int64_t &alloc_size) {
std::lock_guard<std::mutex> lk(mutex_);
if (!is_init) {
InitResource();
is_init = true;
}
auto search = idle_.find(device_id);
if (search != idle_.end()) {
auto &res_vec = search->second;
for (auto it = res_vec.cbegin(); it != res_vec.cend(); ++it) {
if ((*it)->id == res->id) {
//AllocateTempMem(res, device_id, alloc_size);
res_vec.erase(it);
return true;
}
}
}
// else
return false;
}
void FaissGpuResourceMgr::MoveToInuse(const int64_t &device_id, const std::shared_ptr<faiss::gpu::StandardGpuResources> &res) {
void FaissGpuResourceMgr::MoveToInuse(const int64_t &device_id, const ResPtr &res) {
std::lock_guard<std::mutex> lk(mutex_);
in_use_[device_id].push_back(res);
}
void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const std::shared_ptr<faiss::gpu::StandardGpuResources> &res) {
void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res) {
std::lock_guard<std::mutex> lk(mutex_);
idle_[device_id].push_back(res);
auto it = idle_[device_id].begin();
idle_[device_id].insert(it, res);
}
void GPUIndex::SetGpuDevice(const int &gpu_id) {
......
......@@ -39,6 +39,9 @@ DatasetPtr IDMAP::Search(const DatasetPtr &dataset, const Config &config) {
}
auto k = config["k"].as<size_t>();
//auto metric_type = config["metric_type"].as_string() == "L2" ?
// faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT;
//index_->metric_type = metric_type;
GETTENSOR(dataset)
......@@ -135,11 +138,11 @@ VectorIndexPtr IDMAP::Clone() {
VectorIndexPtr IDMAP::CopyCpuToGpu(const int64_t &device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
ResScope rs(device_id, res);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res.get(), device_id, index_.get());
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIDMAP>(device_index, device_id);
return std::make_shared<GPUIDMAP>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
......@@ -204,7 +207,7 @@ void GPUIDMAP::LoadImpl(const BinarySet &index_binary) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_) ){
ResScope rs(gpu_id_, res);
auto device_index = faiss::gpu::index_cpu_to_gpu(res.get(), gpu_id_, index);
auto device_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index);
index_.reset(device_index);
} else {
KNOWHERE_THROW_MSG("Load error, can't get gpu resource");
......
......@@ -197,11 +197,11 @@ void IVF::search_impl(int64_t n,
VectorIndexPtr IVF::CopyCpuToGpu(const int64_t& device_id, const Config &config) {
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)){
ResScope rs(device_id, res);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res.get(), device_id, index_.get());
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIVF>(device_index, device_id);
return std::make_shared<GPUIVF>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
......@@ -275,11 +275,11 @@ VectorIndexPtr IVFSQ::CopyCpuToGpu(const int64_t &device_id, const Config &confi
faiss::gpu::GpuClonerOptions option;
option.allInGpu = true;
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res.get(), device_id, index_.get(), &option);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get(), &option);
std::shared_ptr<faiss::Index> device_index;
device_index.reset(gpu_index);
return std::make_shared<GPUIVFSQ>(device_index, device_id);
return std::make_shared<GPUIVFSQ>(device_index, device_id, res);
} else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
}
......@@ -350,12 +350,16 @@ void BasicIndex::LoadImpl(const BinarySet &index_binary) {
}
void BasicIndex::SealImpl() {
// TODO(linxj): enable
//#ifdef ZILLIZ_FAISS
faiss::Index *index = index_.get();
auto idx = dynamic_cast<faiss::IndexIVF *>(index);
if (idx != nullptr) {
idx->to_readonly();
}
//else {
// KNOHWERE_ERROR_MSG("Seal failed");
//}
//#endif
}
......
......@@ -18,9 +18,11 @@
using namespace zilliz::knowhere;
static int device_id = 0;
class IDMAPTest : public DataGen, public ::testing::Test {
protected:
void SetUp() override {
FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*300, 2);
Init_with_default();
index_ = std::make_shared<IDMAP>();
}
......
......@@ -8,6 +8,12 @@
#include <iostream>
#include <sstream>
#include <thread>
#include <faiss/AutoTune.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuClonerOptions.h>
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "knowhere/index/vector_index/ivf.h"
......@@ -25,7 +31,7 @@ using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::Combine;
static int device_id = 1;
static int device_id = 0;
IVFIndexPtr IndexFactory(const std::string &type) {
if (type == "IVF") {
return std::make_shared<IVF>();
......@@ -50,7 +56,7 @@ class IVFTest
//Init_with_default();
Generate(128, 1000000/5, 10);
index_ = IndexFactory(index_type);
FaissGpuResourceMgr::GetInstance().InitDevice(device_id);
FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*300, 2);
}
protected:
......@@ -343,4 +349,213 @@ TEST_P(IVFTest, seal_test) {
ASSERT_GE(without_seal, with_seal);
}
// Fixture for the GPU resource tests below. Each test gets freshly generated
// data plus two raw output buffers (ids / dis) sized for nq * k results, which
// the tests hand to faiss search calls directly.
class GPURESTEST
    : public DataGen, public ::testing::Test {
 protected:
    void SetUp() override {
        //std::tie(index_type, preprocess_cfg, train_cfg, add_cfg, search_cfg) = GetParam();
        //Init_with_default();
        Generate(128, 1000000, 1000);
        k = 100;
        //index_ = IndexFactory(index_type);
        // NOTE(review): the two size arguments are presumably the pinned and
        // temp memory sizes in bytes, followed by the resource count - confirm
        // against InitDevice's declaration.
        FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024*1024*200, 1024*1024*300, 2);

        elems = nq * k;
        ids = static_cast<int64_t *>(malloc(sizeof(int64_t) * elems));
        dis = static_cast<float *>(malloc(sizeof(float) * elems));
    }

    void TearDown() override {
        // The buffers come from malloc, so they must be released with free;
        // passing them to delete is undefined behaviour.
        free(ids);
        free(dis);
        ids = nullptr;
        dis = nullptr;
    }

 protected:
    std::string index_type;
    Config preprocess_cfg;
    Config train_cfg;
    Config add_cfg;
    Config search_cfg;

    IVFIndexPtr index_ = nullptr;

    // Output buffers for search results; each holds elems entries.
    int64_t *ids = nullptr;
    float *dis = nullptr;
    int64_t elems = 0;
};
// Number of search iterations each timed loop in the GPURESTEST tests runs.
const int search_count = 100;
// Number of CopyCpuToGpu calls issued by the load loop of the copyandsearch test.
const int load_count = 30;
// Timing comparison for IVF search on the GPU: first through the knowhere
// "GPUIVF" wrapper, then through a faiss GpuIndexIVFFlat used directly.
// Only the wrapper's Count() and Dimension() are checked; everything else is
// recorded as timing sections.
TEST_F(GPURESTEST, gpu_ivf_resource_test) {
assert(!xb.empty());
{
// knowhere path: build the index through IndexFactory, then train, add and search
index_type = "GPUIVF";
index_ = IndexFactory(index_type);
auto preprocessor = index_->BuildPreprocessor(base_dataset, preprocess_cfg);
index_->set_preprocessor(preprocessor);
train_cfg = Config::object{{"nlist", 1638}, {"gpu_id", device_id}, {"metric_type", "L2"}};
auto model = index_->Train(base_dataset, train_cfg);
index_->set_index_model(model);
index_->Add(base_dataset, add_cfg);
EXPECT_EQ(index_->Count(), nb);
EXPECT_EQ(index_->Dimension(), dim);
search_cfg = Config::object{{"k", k}};
TimeRecorder tc("knowere GPUIVF");
for (int i = 0; i < search_count; ++i) {
index_->Search(query_dataset, search_cfg);
// record an individual section only for the first five and last five iterations
if (i > search_count - 6 || i < 5)
tc.RecordSection("search once");
}
tc.RecordSection("search all");
}
{
// IVF-Search
// faiss path: same nlist (1638) and L2 metric, with a locally owned StandardGpuResources
faiss::gpu::StandardGpuResources res;
faiss::gpu::GpuIndexIVFFlatConfig idx_config;
idx_config.device = device_id;
faiss::gpu::GpuIndexIVFFlat device_index(&res, dim, 1638, faiss::METRIC_L2, idx_config);
device_index.train(nb, xb.data());
device_index.add(nb, xb.data());
TimeRecorder tc("ori IVF");
for (int i = 0; i < search_count; ++i) {
// results are written into the fixture's dis / ids buffers
device_index.search(nq, xq.data(), k, dis, ids);
if (i > search_count - 6 || i < 5)
tc.RecordSection("search once");
}
tc.RecordSection("search all");
}
}
// Timing comparison for IVF-SQ8 on the GPU: the knowhere "GPUIVFSQ" wrapper
// versus the equivalent sequence of direct faiss calls. Both halves build on
// the GPU, copy to the CPU, make the CPU index read-only (Seal / to_readonly),
// copy back to the GPU and then time repeated searches.
TEST_F(GPURESTEST, gpuivfsq) {
{
// knowhere gpu ivfsq
index_type = "GPUIVFSQ";
index_ = IndexFactory(index_type);
auto preprocessor = index_->BuildPreprocessor(base_dataset, preprocess_cfg);
index_->set_preprocessor(preprocessor);
train_cfg = Config::object{{"gpu_id", device_id}, {"nlist", 1638}, {"nbits", 8}, {"metric_type", "L2"}};
auto model = index_->Train(base_dataset, train_cfg);
index_->set_index_model(model);
index_->Add(base_dataset, add_cfg);
search_cfg = Config::object{{"k", k}};
// one checked search before the timed part
auto result = index_->Search(query_dataset, search_cfg);
AssertAnns(result, nq, k);
auto cpu_idx = CopyGpuToCpu(index_, Config());
cpu_idx->Seal();
TimeRecorder tc("knowhere GPUSQ8");
auto search_idx = CopyCpuToGpu(cpu_idx, device_id, Config());
tc.RecordSection("Copy to gpu");
for (int i = 0; i < search_count; ++i) {
search_idx->Search(query_dataset, search_cfg);
// record an individual section only for the first five and last five iterations
if (i > search_count - 6 || i < 5)
tc.RecordSection("search once");
}
tc.RecordSection("search all");
}
{
// Ori gpuivfsq Test
const char *index_description = "IVF1638,SQ8";
faiss::Index *ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2);
faiss::gpu::StandardGpuResources res;
auto device_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, ori_index);
device_index->train(nb, xb.data());
device_index->add(nb, xb.data());
auto cpu_index = faiss::gpu::index_gpu_to_cpu(device_index);
// mirror of Seal() above: only an IndexIVF is switched to read-only
auto idx = dynamic_cast<faiss::IndexIVF *>(cpu_index);
if (idx != nullptr) {
idx->to_readonly();
}
// the build-time indexes are raw owning pointers and are released by hand
delete device_index;
delete ori_index;
faiss::gpu::GpuClonerOptions option;
option.allInGpu = true;
TimeRecorder tc("ori GPUSQ8");
faiss::Index *search_idx = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, &option);
tc.RecordSection("Copy to gpu");
for (int i = 0; i < search_count; ++i) {
// results are written into the fixture's dis / ids buffers
search_idx->search(nq, xq.data(), k, dis, ids);
if (i > search_count - 6 || i < 5)
tc.RecordSection("search once");
}
tc.RecordSection("search all");
delete cpu_index;
delete search_idx;
}
}
// Exercises searching and CPU-to-GPU copying of the same sealed index, first
// one after the other and then from two threads at once, recording a timing
// section for each phase. Apart from the initial AssertAnns check, the test
// passes as long as nothing throws or crashes.
TEST_F(GPURESTEST, copyandsearch) {
printf("==================\n");
// search and copy at the same time
index_type = "GPUIVFSQ";
//index_type = "GPUIVF";
index_ = IndexFactory(index_type);
auto preprocessor = index_->BuildPreprocessor(base_dataset, preprocess_cfg);
index_->set_preprocessor(preprocessor);
train_cfg = Config::object{{"gpu_id", device_id}, {"nlist", 1638}, {"nbits", 8}, {"metric_type", "L2"}};
auto model = index_->Train(base_dataset, train_cfg);
index_->set_index_model(model);
index_->Add(base_dataset, add_cfg);
search_cfg = Config::object{{"k", k}};
auto result = index_->Search(query_dataset, search_cfg);
AssertAnns(result, nq, k);
auto cpu_idx = CopyGpuToCpu(index_, Config());
cpu_idx->Seal();
// GPU copy that every search below goes through
auto search_idx = CopyCpuToGpu(cpu_idx, device_id, Config());
// Runs search_count searches against search_idx. Captures by reference, so it
// must not outlive this test body (both threads below are joined).
auto search_func = [&] {
//TimeRecorder tc("search&load");
for (int i = 0; i < search_count; ++i) {
search_idx->Search(query_dataset, search_cfg);
//if (i > search_count - 6 || i == 0)
// tc.RecordSection("search once");
}
//tc.ElapseFromBegin("search finish");
};
// Performs load_count CPU-to-GPU copies; each returned index is discarded immediately.
auto load_func = [&] {
//TimeRecorder tc("search&load");
for (int i = 0; i < load_count; ++i) {
CopyCpuToGpu(cpu_idx, device_id, Config());
//if (i > load_count -5 || i < 5)
//tc.RecordSection("Copy to gpu");
}
//tc.ElapseFromBegin("load finish");
};
// Baseline: one copy, one search, then each loop on its own
TimeRecorder tc("basic");
CopyCpuToGpu(cpu_idx, device_id, Config());
tc.RecordSection("Copy to gpu once");
search_idx->Search(query_dataset, search_cfg);
tc.RecordSection("search once");
search_func();
tc.RecordSection("only search total");
load_func();
tc.RecordSection("only copy total");
// Concurrent phase: both loops run at the same time on separate threads
std::thread search_thread(search_func);
std::thread load_thread(load_func);
search_thread.join();
load_thread.join();
tc.RecordSection("Copy&search total");
}
// TODO(linxj): Add exception test
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "DBImpl.h"
#include "Factories.h"
namespace zilliz {
namespace milvus {
namespace engine {
DB::~DB() = default;
void DB::Open(const Options& options, DB** dbptr) {
*dbptr = DBFactory::Build(options);
}
} // namespace engine
} // namespace milvus
} // namespace zilliz
......@@ -20,7 +20,11 @@ class Env;
class DB {
public:
static void Open(const Options& options, DB** dbptr);
DB() = default;
DB(const DB&) = delete;
DB& operator=(const DB&) = delete;
virtual ~DB() = default;
virtual Status Start() = 0;
virtual Status Stop() = 0;
......@@ -55,11 +59,6 @@ public:
virtual Status DropAll() = 0;
DB() = default;
DB(const DB&) = delete;
DB& operator=(const DB&) = delete;
virtual ~DB() = 0;
}; // DB
} // namespace engine
......
......@@ -53,11 +53,15 @@ DBImpl::~DBImpl() {
Stop();
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//external api
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status DBImpl::Start() {
if (!shutting_down_.load(std::memory_order_acquire)){
return Status::OK();
}
ENGINE_LOG_TRACE << "DB service start";
shutting_down_.store(false, std::memory_order_release);
//for distribute version, some nodes are read only
......@@ -75,30 +79,40 @@ Status DBImpl::Stop() {
}
shutting_down_.store(true, std::memory_order_release);
bg_timer_thread_.join();
//makesure all memory data serialized
MemSerialize();
//wait compaction/buildindex finish
for(auto& result : compact_thread_results_) {
result.wait();
}
bg_timer_thread_.join();
for(auto& result : index_thread_results_) {
result.wait();
if (options_.mode != Options::MODE::READ_ONLY) {
meta_ptr_->CleanUp();
}
//makesure all memory data serialized
MemSerialize();
ENGINE_LOG_TRACE << "DB service stop";
return Status::OK();
}
// Forwards to the meta store's DropAll() and returns its status unchanged.
// Unlike most other DBImpl entry points, no shutting_down_ check is made here.
Status DBImpl::DropAll() {
return meta_ptr_->DropAll();
}
Status DBImpl::CreateTable(meta::TableSchema& table_schema) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
meta::TableSchema temp_schema = table_schema;
temp_schema.index_file_size_ *= ONE_MB;
temp_schema.index_file_size_ *= ONE_MB; //store as MB
return meta_ptr_->CreateTable(temp_schema);
}
Status DBImpl::DeleteTable(const std::string& table_id, const meta::DatesT& dates) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
//dates partly delete files of the table but currently we don't support
ENGINE_LOG_DEBUG << "Prepare to delete table " << table_id;
......@@ -121,18 +135,36 @@ Status DBImpl::DeleteTable(const std::string& table_id, const meta::DatesT& date
}
Status DBImpl::DescribeTable(meta::TableSchema& table_schema) {
return meta_ptr_->DescribeTable(table_schema);
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
auto stat = meta_ptr_->DescribeTable(table_schema);
table_schema.index_file_size_ /= ONE_MB; //return as MB
return stat;
}
// Asks the meta store whether table_id exists; the answer is written to
// has_or_not. While the shutting_down_ flag is set, the meta store is not
// consulted and an error status is returned instead.
Status DBImpl::HasTable(const std::string& table_id, bool& has_or_not) {
    const bool stopped = shutting_down_.load(std::memory_order_acquire);
    if (!stopped) {
        return meta_ptr_->HasTable(table_id, has_or_not);
    }

    return Status::Error("Milsvus server is shutdown!");
}
// Fills table_schema_array from the meta store's AllTables(). While the
// shutting_down_ flag is set, the meta store is not consulted and an error
// status is returned instead.
Status DBImpl::AllTables(std::vector<meta::TableSchema>& table_schema_array) {
    const bool stopped = shutting_down_.load(std::memory_order_acquire);
    if (!stopped) {
        return meta_ptr_->AllTables(table_schema_array);
    }

    return Status::Error("Milsvus server is shutdown!");
}
Status DBImpl::PreloadTable(const std::string &table_id) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
meta::DatePartionedTableFilesSchema files;
meta::DatesT dates;
......@@ -174,16 +206,27 @@ Status DBImpl::PreloadTable(const std::string &table_id) {
}
// Passes the new flag for table_id on to the meta store. While the
// shutting_down_ flag is set, the meta store is not touched and an error
// status is returned instead.
Status DBImpl::UpdateTableFlag(const std::string &table_id, int64_t flag) {
    const bool stopped = shutting_down_.load(std::memory_order_acquire);
    if (!stopped) {
        return meta_ptr_->UpdateTableFlag(table_id, flag);
    }

    return Status::Error("Milsvus server is shutdown!");
}
// Obtains the row count of table_id via the meta store's Count() and writes it
// to row_count. While the shutting_down_ flag is set, the meta store is not
// consulted and an error status is returned instead.
Status DBImpl::GetTableRowCount(const std::string& table_id, uint64_t& row_count) {
    const bool stopped = shutting_down_.load(std::memory_order_acquire);
    if (!stopped) {
        return meta_ptr_->Count(table_id, row_count);
    }

    return Status::Error("Milsvus server is shutdown!");
}
Status DBImpl::InsertVectors(const std::string& table_id_,
uint64_t n, const float* vectors, IDNumbers& vector_ids_) {
ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";
// ENGINE_LOG_DEBUG << "Insert " << n << " vectors to cache";
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
Status status;
zilliz::milvus::server::CollectInsertMetrics metrics(n, status);
......@@ -191,13 +234,89 @@ Status DBImpl::InsertVectors(const std::string& table_id_,
// std::chrono::microseconds time_span = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
// double average_time = double(time_span.count()) / n;
ENGINE_LOG_DEBUG << "Insert vectors to cache finished";
// ENGINE_LOG_DEBUG << "Insert vectors to cache finished";
return status;
}
// Records the requested index parameters for table_id and then blocks until no
// file of the table is left in a "not yet indexed" state.
//
// @param table_id  table whose index is to be (re)built
// @param index     requested index parameters; its metric type is ignored and
//                  the one already stored for the table is kept
// @return OK once no matching files remain; the failing status when the meta
//         store reports an error; an error status when the server is (or
//         becomes) shut down, because nothing would ever finish the build.
Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) {
    // Same guard as the other public entry points: refuse work after shutdown.
    if (shutting_down_.load(std::memory_order_acquire)){
        return Status::Error("Milsvus server is shutdown!");
    }

    {
        std::unique_lock<std::mutex> lock(build_index_mutex_);

        //step 1: check index difference
        TableIndex old_index;
        auto status = DescribeIndex(table_id, old_index);
        if(!status.ok()) {
            ENGINE_LOG_ERROR << "Failed to get table index info for table: " << table_id;
            return status;
        }

        //step 2: update index info
        TableIndex new_index = index;
        new_index.metric_type_ = old_index.metric_type_;//dont change metric type, it was defined by CreateTable
        if(!utils::IsSameIndex(old_index, new_index)) {
            // Dropping the old index stays best-effort, but a failure is no longer silent.
            auto drop_status = DropIndex(table_id);
            if (!drop_status.ok()) {
                ENGINE_LOG_ERROR << "Failed to drop old index for table: " << table_id;
            }

            status = meta_ptr_->UpdateTableIndexParam(table_id, new_index);
            if (!status.ok()) {
                ENGINE_LOG_ERROR << "Failed to update table index info for table: " << table_id;
                return status;
            }
        }
    }

    //step 3: wait and build index
    //for IDMAP type, only wait all NEW file converted to RAW file
    //for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files
    std::vector<int> file_types;
    if(index.engine_type_ == (int)EngineType::FAISS_IDMAP) {
        file_types = {
            (int) meta::TableFileSchema::NEW,
            (int) meta::TableFileSchema::NEW_MERGE,
        };
    } else {
        file_types = {
            (int) meta::TableFileSchema::RAW,
            (int) meta::TableFileSchema::NEW,
            (int) meta::TableFileSchema::NEW_MERGE,
            (int) meta::TableFileSchema::NEW_INDEX,
            (int) meta::TableFileSchema::TO_INDEX,
        };
    }

    std::vector<std::string> file_ids;
    auto status = meta_ptr_->FilesByType(table_id, file_types, file_ids);
    if (!status.ok()) {
        ENGINE_LOG_ERROR << "Failed to get files to index for table: " << table_id;
        return status;
    }

    int times = 1;
    while (!file_ids.empty()) {
        // Once the server is stopping nobody will convert the remaining files,
        // so waiting any longer would never terminate.
        if (shutting_down_.load(std::memory_order_acquire)){
            return Status::Error("Milsvus server is shutdown!");
        }

        ENGINE_LOG_DEBUG << "Non index files detected! Will build index " << times;
        if(index.engine_type_ != (int)EngineType::FAISS_IDMAP) {
            status = meta_ptr_->UpdateTableFilesToIndex(table_id);
            if (!status.ok()) {
                ENGINE_LOG_ERROR << "Failed to mark files to index for table: " << table_id;
                return status;
            }
        }

        // Back off linearly (100ms per attempt), capped at 10 seconds.
        std::this_thread::sleep_for(std::chrono::milliseconds(std::min(10*1000, times*100)));

        // The loop condition depends on this call, so its failure must not be ignored.
        status = meta_ptr_->FilesByType(table_id, file_types, file_ids);
        if (!status.ok()) {
            ENGINE_LOG_ERROR << "Failed to get files to index for table: " << table_id;
            return status;
        }
        times++;
    }

    return Status::OK();
}
// Reads the index description of a table from the meta layer.
//
// table_id: identifier of the table to look up.
// index:    output argument filled in by meta_ptr_->DescribeTableIndex().
// Returns the Status produced by the meta call.
Status DBImpl::DescribeIndex(const std::string& table_id, TableIndex& index) {
    Status describe_status = meta_ptr_->DescribeTableIndex(table_id, index);
    return describe_status;
}
// Logs the request and asks the meta layer to drop the index of a table.
//
// table_id: identifier of the table whose index is dropped.
// Returns the Status produced by meta_ptr_->DropTableIndex().
Status DBImpl::DropIndex(const std::string& table_id) {
    ENGINE_LOG_DEBUG << "Drop index for table: " << table_id;

    Status drop_status = meta_ptr_->DropTableIndex(table_id);
    return drop_status;
}
Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float *vectors, QueryResults &results) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
meta::DatesT dates = {utils::GetDate()};
Status result = Query(table_id, k, nq, nprobe, vectors, dates, results);
......@@ -206,6 +325,10 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint6
Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe,
const float* vectors, const meta::DatesT& dates, QueryResults& results) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
ENGINE_LOG_DEBUG << "Query by dates for table: " << table_id;
//get all table files from table
......@@ -230,6 +353,10 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint6
Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>& file_ids,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
if (shutting_down_.load(std::memory_order_acquire)){
return Status::Error("Milsvus server is shutdown!");
}
ENGINE_LOG_DEBUG << "Query by file ids for table: " << table_id;
//get specified files
......@@ -264,6 +391,18 @@ Status DBImpl::Query(const std::string& table_id, const std::vector<std::string>
return status;
}
// Asks the meta layer for the size value it tracks.
//
// result: output argument filled in by meta_ptr_->Size().
// Returns an error Status when the shutdown flag is set; otherwise the
// result of the meta call.
Status DBImpl::Size(uint64_t& result) {
    const bool stopping = shutting_down_.load(std::memory_order_acquire);
    if (!stopping) {
        return meta_ptr_->Size(result);
    }

    // The request is refused once the shutdown flag has been observed.
    return Status::Error("Milsvus server is shutdown!");
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//internal methods
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSchema& files,
uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors,
const meta::DatesT& dates, QueryResults& results) {
......@@ -522,7 +661,6 @@ void DBImpl::BackgroundCompaction(std::set<std::string> table_ids) {
status = BackgroundMergeFiles(table_id);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Merge files for table " << table_id << " failed: " << status.ToString();
continue;//let other table get chance to merge
}
if (shutting_down_.load(std::memory_order_acquire)){
......@@ -564,76 +702,6 @@ void DBImpl::StartBuildIndexTask(bool force) {
}
}
// Stores new index parameters for a table and then waits until the meta
// layer reports no more files of the awaited types.
//
// Steps:
//   1. (under build_index_mutex_) read the current index description;
//   2. (under build_index_mutex_) if the requested index differs from the
//      stored one, drop the old index and persist the requested parameters;
//   3. poll FilesByType() with a growing sleep (100ms steps, capped at 10s)
//      until it returns no file ids.
//
// table_id: identifier of the table to index.
// index:    requested index parameters.
// Returns the failing Status of the first meta call that reports an error,
// otherwise Status::OK().
Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) {
    {
        std::unique_lock<std::mutex> lock(build_index_mutex_);

        //step 1: check index difference
        TableIndex old_index;
        auto status = DescribeIndex(table_id, old_index);
        if(!status.ok()) {
            ENGINE_LOG_ERROR << "Failed to get table index info for table: " << table_id;
            return status;
        }

        //step 2: update index info
        if(!utils::IsSameIndex(old_index, index)) {
            // NOTE(review): the result of DropIndex is not checked here; confirm
            // whether a failed drop should abort before the new parameters are stored.
            DropIndex(table_id);

            status = meta_ptr_->UpdateTableIndexParam(table_id, index);
            if (!status.ok()) {
                ENGINE_LOG_ERROR << "Failed to update table index info for table: " << table_id;
                return status;
            }
        }
    }

    //step 3: wait and build index
    //for IDMAP type, only wait all NEW file converted to RAW file
    //for other type, wait NEW/RAW/NEW_MERGE/NEW_INDEX/TO_INDEX files converted to INDEX files
    const bool is_idmap = (index.engine_type_ == (int)EngineType::FAISS_IDMAP);
    std::vector<int> file_types;
    if(is_idmap) {
        file_types = {
            (int) meta::TableFileSchema::NEW,
            (int) meta::TableFileSchema::NEW_MERGE,
        };
    } else {
        file_types = {
            (int) meta::TableFileSchema::RAW,
            (int) meta::TableFileSchema::NEW,
            (int) meta::TableFileSchema::NEW_MERGE,
            (int) meta::TableFileSchema::NEW_INDEX,
            (int) meta::TableFileSchema::TO_INDEX,
        };
    }

    std::vector<std::string> file_ids;
    auto status = meta_ptr_->FilesByType(table_id, file_types, file_ids);
    if (!status.ok()) {
        // Without a successful listing the loop condition below is meaningless.
        ENGINE_LOG_ERROR << "Failed to get files by type for table: " << table_id;
        return status;
    }

    int times = 1;
    while (!file_ids.empty()) {
        ENGINE_LOG_DEBUG << "Non index files detected! Will build index " << times;
        if(!is_idmap) {
            status = meta_ptr_->UpdateTableFilesToIndex(table_id);
            if (!status.ok()) {
                // Previously this status was overwritten unread; report it instead of polling on.
                ENGINE_LOG_ERROR << "Failed to update files to index for table: " << table_id;
                return status;
            }
        }

        // Back off a little more on every round, up to 10 seconds per wait.
        std::this_thread::sleep_for(std::chrono::milliseconds(std::min(10*1000, times*100)));

        status = meta_ptr_->FilesByType(table_id, file_types, file_ids);
        if (!status.ok()) {
            // A failed refresh may leave file_ids stale, which would keep this loop running.
            ENGINE_LOG_ERROR << "Failed to get files by type for table: " << table_id;
            return status;
        }
        times++;
    }

    return Status::OK();
}
// Reads the index description of a table from the meta layer.
//
// table_id: identifier of the table to look up.
// index:    output argument filled in by meta_ptr_->DescribeTableIndex().
// Returns the Status produced by the meta call.
Status DBImpl::DescribeIndex(const std::string& table_id, TableIndex& index) {
    Status describe_status = meta_ptr_->DescribeTableIndex(table_id, index);
    return describe_status;
}
// Logs the request and asks the meta layer to drop the index of a table.
//
// table_id: identifier of the table whose index is dropped.
// Returns the Status produced by meta_ptr_->DropTableIndex().
Status DBImpl::DropIndex(const std::string& table_id) {
    ENGINE_LOG_DEBUG << "Drop index for table: " << table_id;

    Status drop_status = meta_ptr_->DropTableIndex(table_id);
    return drop_status;
}
Status DBImpl::BuildIndex(const meta::TableFileSchema& file) {
ExecutionEnginePtr to_index =
EngineFactory::Build(file.dimension_, file.location_, (EngineType)file.engine_type_,
......@@ -765,7 +833,6 @@ void DBImpl::BackgroundBuildIndex() {
status = BuildIndex(file);
if (!status.ok()) {
ENGINE_LOG_ERROR << "Building index for " << file.id_ << " failed: " << status.ToString();
return;
}
if (shutting_down_.load(std::memory_order_acquire)){
......@@ -777,14 +844,6 @@ void DBImpl::BackgroundBuildIndex() {
ENGINE_LOG_TRACE << "Background build index thread exit";
}
// Delegates to meta_ptr_->DropAll() and returns its Status unchanged.
Status DBImpl::DropAll() {
    Status drop_status = meta_ptr_->DropAll();
    return drop_status;
}
// Asks the meta layer for the size value it tracks.
//
// result: output argument filled in by meta_ptr_->Size().
// Returns the Status produced by the meta call.
Status DBImpl::Size(uint64_t& result) {
    Status size_status = meta_ptr_->Size(result);
    return size_status;
}
} // namespace engine
} // namespace milvus
} // namespace zilliz
......@@ -35,9 +35,11 @@ class DBImpl : public DB {
using MetaPtr = meta::Meta::Ptr;
explicit DBImpl(const Options &options);
~DBImpl();
Status Start() override;
Status Stop() override;
Status DropAll() override;
Status CreateTable(meta::TableSchema &table_schema) override;
......@@ -57,6 +59,12 @@ class DBImpl : public DB {
Status InsertVectors(const std::string &table_id, uint64_t n, const float *vectors, IDNumbers &vector_ids) override;
Status CreateIndex(const std::string& table_id, const TableIndex& index) override;
Status DescribeIndex(const std::string& table_id, TableIndex& index) override;
Status DropIndex(const std::string& table_id) override;
Status Query(const std::string &table_id,
uint64_t k,
uint64_t nq,
......@@ -81,18 +89,8 @@ class DBImpl : public DB {
const meta::DatesT &dates,
QueryResults &results) override;
Status DropAll() override;
Status Size(uint64_t &result) override;
Status CreateIndex(const std::string& table_id, const TableIndex& index) override;
Status DescribeIndex(const std::string& table_id, TableIndex& index) override;
Status DropIndex(const std::string& table_id) override;
~DBImpl() override;
private:
Status QueryAsync(const std::string &table_id,
const meta::TableFilesSchema &files,
......
......@@ -301,7 +301,8 @@ Status ExecutionEngineImpl::Search(long n,
}
ENGINE_LOG_DEBUG << "Search Params: [k] " << k << " [nprobe] " << nprobe;
auto ec = index_->Search(n, data, distances, labels, Config::object{{"k", k}, {"nprobe", nprobe}});
auto cfg = Config::object{{"k", k}, {"nprobe", nprobe}};
auto ec = index_->Search(n, data, distances, labels, cfg);
if (ec != server::KNOWHERE_SUCCESS) {
ENGINE_LOG_ERROR << "Search error";
return Status::Error("Search: Search Error");
......
......@@ -48,9 +48,7 @@ MySQLMetaImpl::MySQLMetaImpl(const DBMetaOptions &options_, const int &mode)
}
// Destructor: runs CleanUp() unless mode_ equals Options::MODE::READ_ONLY.
MySQLMetaImpl::~MySQLMetaImpl() {
    const bool read_only = (mode_ == Options::MODE::READ_ONLY);
    if (read_only) {
        return;
    }

    CleanUp();
}
Status MySQLMetaImpl::NextTableId(std::string &table_id) {
......@@ -2001,10 +1999,8 @@ Status MySQLMetaImpl::Count(const std::string &table_id, uint64_t &result) {
}
Status MySQLMetaImpl::DropAll() {
if (boost::filesystem::is_directory(options_.path)) {
boost::filesystem::remove_all(options_.path);
}
try {
ENGINE_LOG_DEBUG << "Drop all mysql meta";
ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab);
if (connectionPtr == nullptr) {
......
......@@ -74,7 +74,7 @@ SqliteMetaImpl::SqliteMetaImpl(const DBMetaOptions &options_)
}
// Destructor: unconditionally invokes CleanUp() when the object is destroyed.
SqliteMetaImpl::~SqliteMetaImpl() {
CleanUp();
}
Status SqliteMetaImpl::NextTableId(std::string &table_id) {
......@@ -707,7 +707,7 @@ Status SqliteMetaImpl::FilesToSearch(const std::string &table_id,
files[table_file.date_].push_back(table_file);
}
if(files.empty()) {
std::cout << "ERROR" << std::endl;
ENGINE_LOG_ERROR << "No file to search for table: " << table_id;
}
} catch (std::exception &e) {
return HandleException("Encounter exception when iterate index files", e);
......@@ -1205,9 +1205,15 @@ Status SqliteMetaImpl::Count(const std::string &table_id, uint64_t &result) {
}
Status SqliteMetaImpl::DropAll() {
if (boost::filesystem::is_directory(options_.path)) {
boost::filesystem::remove_all(options_.path);
ENGINE_LOG_DEBUG << "Drop all sqlite meta";
try {
ConnectorPtr->drop_table("Tables");
ConnectorPtr->drop_table("TableFiles");
} catch (std::exception &e) {
return HandleException("Encounter exception when drop all meta", e);
}
return Status::OK();
}
......
......@@ -49,8 +49,8 @@ MilvusService::Stub::Stub(const std::shared_ptr< ::grpc::ChannelInterface>& chan
, rpcmethod_DropTable_(MilvusService_method_names[2], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_CreateIndex_(MilvusService_method_names[3], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_Insert_(MilvusService_method_names[4], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_Search_(MilvusService_method_names[5], ::grpc::internal::RpcMethod::SERVER_STREAMING, channel)
, rpcmethod_SearchInFiles_(MilvusService_method_names[6], ::grpc::internal::RpcMethod::SERVER_STREAMING, channel)
, rpcmethod_Search_(MilvusService_method_names[5], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_SearchInFiles_(MilvusService_method_names[6], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_DescribeTable_(MilvusService_method_names[7], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_CountTable_(MilvusService_method_names[8], ::grpc::internal::RpcMethod::NORMAL_RPC, channel)
, rpcmethod_ShowTables_(MilvusService_method_names[9], ::grpc::internal::RpcMethod::SERVER_STREAMING, channel)
......@@ -201,36 +201,60 @@ void MilvusService::Stub::experimental_async::Insert(::grpc::ClientContext* cont
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::VectorIds>::Create(channel_.get(), cq, rpcmethod_Insert_, context, request, false);
}
::grpc::ClientReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::SearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request) {
return ::grpc_impl::internal::ClientReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), rpcmethod_Search_, context, request);
::grpc::Status MilvusService::Stub::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::milvus::grpc::TopKQueryResultList* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Search_, context, request, response);
}
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, ::milvus::grpc::SearchParam* request, ::grpc::experimental::ClientReadReactor< ::milvus::grpc::TopKQueryResult>* reactor) {
::grpc_impl::internal::ClientCallbackReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, reactor);
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
}
::grpc::ClientAsyncReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq, void* tag) {
return ::grpc_impl::internal::ClientAsyncReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true, tag);
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, std::move(f));
}
::grpc::ClientAsyncReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false, nullptr);
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
}
::grpc::ClientReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::SearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request) {
return ::grpc_impl::internal::ClientReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), rpcmethod_SearchInFiles_, context, request);
void MilvusService::Stub::experimental_async::Search(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_Search_, context, request, response, reactor);
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, ::milvus::grpc::SearchInFilesParam* request, ::grpc::experimental::ClientReadReactor< ::milvus::grpc::TopKQueryResult>* reactor) {
::grpc_impl::internal::ClientCallbackReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, reactor);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, true);
}
::grpc::ClientAsyncReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq, void* tag) {
return ::grpc_impl::internal::ClientAsyncReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true, tag);
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_Search_, context, request, false);
}
::grpc::ClientAsyncReader< ::milvus::grpc::TopKQueryResult>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncReaderFactory< ::milvus::grpc::TopKQueryResult>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false, nullptr);
::grpc::Status MilvusService::Stub::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::milvus::grpc::TopKQueryResultList* response) {
return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_SearchInFiles_, context, request, response);
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, std::function<void(::grpc::Status)> f) {
::grpc_impl::internal::CallbackUnaryCall(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, std::move(f));
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
}
void MilvusService::Stub::experimental_async::SearchInFiles(::grpc::ClientContext* context, const ::grpc::ByteBuffer* request, ::milvus::grpc::TopKQueryResultList* response, ::grpc::experimental::ClientUnaryReactor* reactor) {
::grpc_impl::internal::ClientCallbackUnaryFactory::Create(stub_->channel_.get(), stub_->rpcmethod_SearchInFiles_, context, request, response, reactor);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::AsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, true);
}
::grpc::ClientAsyncResponseReader< ::milvus::grpc::TopKQueryResultList>* MilvusService::Stub::PrepareAsyncSearchInFilesRaw(::grpc::ClientContext* context, const ::milvus::grpc::SearchInFilesParam& request, ::grpc::CompletionQueue* cq) {
return ::grpc_impl::internal::ClientAsyncResponseReaderFactory< ::milvus::grpc::TopKQueryResultList>::Create(channel_.get(), cq, rpcmethod_SearchInFiles_, context, request, false);
}
::grpc::Status MilvusService::Stub::DescribeTable(::grpc::ClientContext* context, const ::milvus::grpc::TableName& request, ::milvus::grpc::TableSchema* response) {
......@@ -473,13 +497,13 @@ MilvusService::Service::Service() {
std::mem_fn(&MilvusService::Service::Insert), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[5],
::grpc::internal::RpcMethod::SERVER_STREAMING,
new ::grpc::internal::ServerStreamingHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResult>(
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchParam, ::milvus::grpc::TopKQueryResultList>(
std::mem_fn(&MilvusService::Service::Search), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[6],
::grpc::internal::RpcMethod::SERVER_STREAMING,
new ::grpc::internal::ServerStreamingHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResult>(
::grpc::internal::RpcMethod::NORMAL_RPC,
new ::grpc::internal::RpcMethodHandler< MilvusService::Service, ::milvus::grpc::SearchInFilesParam, ::milvus::grpc::TopKQueryResultList>(
std::mem_fn(&MilvusService::Service::SearchInFiles), this)));
AddMethod(new ::grpc::internal::RpcServiceMethod(
MilvusService_method_names[7],
......@@ -561,17 +585,17 @@ MilvusService::Service::~Service() {
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::grpc::ServerWriter< ::milvus::grpc::TopKQueryResult>* writer) {
::grpc::Status MilvusService::Service::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request, ::milvus::grpc::TopKQueryResultList* response) {
(void) context;
(void) request;
(void) writer;
(void) response;
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::grpc::ServerWriter< ::milvus::grpc::TopKQueryResult>* writer) {
::grpc::Status MilvusService::Service::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request, ::milvus::grpc::TopKQueryResultList* response) {
(void) context;
(void) request;
(void) writer;
(void) response;
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
......
此差异已折叠。
......@@ -19,6 +19,7 @@ message TableSchema {
TableName table_name = 1;
int64 dimension = 2;
int64 index_file_size = 3;
int32 metric_type = 4;
}
/**
......@@ -84,8 +85,15 @@ message QueryResult {
* @brief TopK query result
*/
message TopKQueryResult {
repeated QueryResult query_result_arrays = 1;
}
/**
* @brief List of topK query result
*/
message TopKQueryResultList {
Status status = 1;
repeated QueryResult query_result_arrays = 2;
repeated TopKQueryResult topk_query_result = 2;
}
/**
......@@ -127,7 +135,6 @@ message Command {
message Index {
int32 index_type = 1;
int32 nlist = 2;
int32 metric_type = 3;
}
/**
......@@ -211,7 +218,7 @@ service MilvusService {
*
* @return query result array.
*/
rpc Search(SearchParam) returns (stream TopKQueryResult) {}
rpc Search(SearchParam) returns (TopKQueryResultList) {}
/**
* @brief Internal use query interface
......@@ -225,7 +232,7 @@ service MilvusService {
*
* @return query result array.
*/
rpc SearchInFiles(SearchInFilesParam) returns (stream TopKQueryResult) {}
rpc SearchInFiles(SearchInFilesParam) returns (TopKQueryResultList) {}
/**
* @brief Get table schema
......
......@@ -29,17 +29,34 @@ class Metrics {
private:
static MetricsBase &CreateMetricsCollector();
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Base for metric collectors that measure a duration: it records a time
// point when constructed and lets derived classes query the time elapsed
// since then. Construction and destruction are restricted to derived classes.
class CollectMetricsBase {
protected:
    using TIME_POINT = std::chrono::system_clock::time_point;

    // Remember the moment of construction as the start of the measurement.
    CollectMetricsBase() : start_time_(METRICS_NOW_TIME) {
    }

    virtual ~CollectMetricsBase() = default;

    // Time elapsed since construction, as computed by METRICS_MICROSECONDS.
    double TimeFromBegine() {
        const auto now = METRICS_NOW_TIME;
        return METRICS_MICROSECONDS(start_time_, now);
    }

protected:
    TIME_POINT start_time_;
};
class CollectInsertMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectInsertMetrics : CollectMetricsBase {
public:
CollectInsertMetrics(size_t n, engine::Status& status) : n_(n), status_(status) {
start_time_ = METRICS_NOW_TIME;
}
~CollectInsertMetrics() {
if(n_ > 0) {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
double avg_time = total_time / n_;
for (int i = 0; i < n_; ++i) {
Metrics::GetInstance().AddVectorsDurationHistogramOberve(avg_time);
......@@ -57,22 +74,19 @@ public:
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
size_t n_;
engine::Status& status_;
};
class CollectQueryMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectQueryMetrics : CollectMetricsBase {
public:
CollectQueryMetrics(size_t nq) : nq_(nq) {
start_time_ = METRICS_NOW_TIME;
}
~CollectQueryMetrics() {
if(nq_ > 0) {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
for (int i = 0; i < nq_; ++i) {
server::Metrics::GetInstance().QueryResponseSummaryObserve(total_time);
}
......@@ -83,112 +97,90 @@ public:
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
size_t nq_;
};
class CollectMergeFilesMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectMergeFilesMetrics : CollectMetricsBase {
public:
CollectMergeFilesMetrics() {
start_time_ = METRICS_NOW_TIME;
}
~CollectMergeFilesMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().MemTableMergeDurationSecondsHistogramObserve(total_time);
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
};
class CollectBuildIndexMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectBuildIndexMetrics : CollectMetricsBase {
public:
CollectBuildIndexMetrics() {
start_time_ = METRICS_NOW_TIME;
}
~CollectBuildIndexMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().BuildIndexDurationSecondsHistogramObserve(total_time);
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
};
class CollectExecutionEngineMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectExecutionEngineMetrics : CollectMetricsBase {
public:
CollectExecutionEngineMetrics(double physical_size) : physical_size_(physical_size) {
start_time_ = METRICS_NOW_TIME;
}
~CollectExecutionEngineMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time);
server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(physical_size_);
server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(physical_size_ / double(total_time));
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
double physical_size_;
};
class CollectSerializeMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectSerializeMetrics : CollectMetricsBase {
public:
CollectSerializeMetrics(size_t size) : size_(size) {
start_time_ = METRICS_NOW_TIME;
}
~CollectSerializeMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().DiskStoreIOSpeedGaugeSet((double) size_ / total_time);
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
size_t size_;
};
class CollectAddMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectAddMetrics : CollectMetricsBase {
public:
CollectAddMetrics(size_t n, uint16_t dimension) : n_(n), dimension_(dimension) {
start_time_ = METRICS_NOW_TIME;
}
~CollectAddMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast<int>(n_),
static_cast<int>(dimension_),
total_time);
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
size_t n_;
uint16_t dimension_;
};
class CollectDurationMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectDurationMetrics : CollectMetricsBase {
public:
CollectDurationMetrics(int index_type) : index_type_(index_type) {
start_time_ = METRICS_NOW_TIME;
}
~CollectDurationMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
switch (index_type_) {
case engine::meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
......@@ -205,20 +197,17 @@ public:
}
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
int index_type_;
};
class CollectSearchTaskMetrics {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CollectSearchTaskMetrics : CollectMetricsBase {
public:
CollectSearchTaskMetrics(int index_type) : index_type_(index_type) {
start_time_ = METRICS_NOW_TIME;
}
~CollectSearchTaskMetrics() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
switch(index_type_) {
case engine::meta::TableFileSchema::RAW: {
server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
......@@ -236,27 +225,20 @@ public:
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
int index_type_;
};
class MetricCollector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class MetricCollector : CollectMetricsBase {
public:
MetricCollector() {
server::Metrics::GetInstance().MetaAccessTotalIncrement();
start_time_ = METRICS_NOW_TIME;
}
~MetricCollector() {
auto end_time = METRICS_NOW_TIME;
auto total_time = METRICS_MICROSECONDS(start_time_, end_time);
auto total_time = TimeFromBegine();
server::Metrics::GetInstance().MetaAccessDurationSecondsHistogramObserve(total_time);
}
private:
using TIME_POINT = std::chrono::system_clock::time_point;
TIME_POINT start_time_;
};
......
......@@ -20,12 +20,12 @@ ShortestPath(const ResourcePtr &src,
std::vector<std::vector<std::string>> paths;
uint64_t num_of_resources = res_mgr->GetAllResouces().size();
uint64_t num_of_resources = res_mgr->GetAllResources().size();
std::unordered_map<uint64_t, std::string> id_name_map;
std::unordered_map<std::string, uint64_t> name_id_map;
for (uint64_t i = 0; i < num_of_resources; ++i) {
id_name_map.insert(std::make_pair(i, res_mgr->GetAllResouces().at(i)->Name()));
name_id_map.insert(std::make_pair(res_mgr->GetAllResouces().at(i)->Name(), i));
id_name_map.insert(std::make_pair(i, res_mgr->GetAllResources().at(i)->name()));
name_id_map.insert(std::make_pair(res_mgr->GetAllResources().at(i)->name(), i));
}
std::vector<std::vector<uint64_t> > dis_matrix;
......@@ -40,23 +40,23 @@ ShortestPath(const ResourcePtr &src,
std::vector<bool> vis(num_of_resources, false);
std::vector<uint64_t> dis(num_of_resources, MAXINT);
for (auto &res : res_mgr->GetAllResouces()) {
for (auto &res : res_mgr->GetAllResources()) {
auto cur_node = std::static_pointer_cast<Node>(res);
auto cur_neighbours = cur_node->GetNeighbours();
for (auto &neighbour : cur_neighbours) {
auto neighbour_res = std::static_pointer_cast<Resource>(neighbour.neighbour_node.lock());
dis_matrix[name_id_map.at(res->Name())][name_id_map.at(neighbour_res->Name())] =
dis_matrix[name_id_map.at(res->name())][name_id_map.at(neighbour_res->name())] =
neighbour.connection.transport_cost();
}
}
for (uint64_t i = 0; i < num_of_resources; ++i) {
dis[i] = dis_matrix[name_id_map.at(src->Name())][i];
dis[i] = dis_matrix[name_id_map.at(src->name())][i];
}
vis[name_id_map.at(src->Name())] = true;
vis[name_id_map.at(src->name())] = true;
std::vector<int64_t> parent(num_of_resources, -1);
for (uint64_t i = 0; i < num_of_resources; ++i) {
......@@ -71,7 +71,7 @@ ShortestPath(const ResourcePtr &src,
vis[temp] = true;
if (i == 0) {
parent[temp] = name_id_map.at(src->Name());
parent[temp] = name_id_map.at(src->name());
}
for (uint64_t j = 0; j < num_of_resources; ++j) {
......@@ -82,15 +82,15 @@ ShortestPath(const ResourcePtr &src,
}
}
int64_t parent_idx = parent[name_id_map.at(dest->Name())];
int64_t parent_idx = parent[name_id_map.at(dest->name())];
if (parent_idx != -1) {
path.push_back(dest->Name());
path.push_back(dest->name());
}
while (parent_idx != -1) {
path.push_back(id_name_map.at(parent_idx));
parent_idx = parent[parent_idx];
}
return dis[name_id_map.at(dest->Name())];
return dis[name_id_map.at(dest->name())];
}
}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {

// Placeholder cache manager: an empty type with no members or behavior.
class CacheMgr {};

// Shared-ownership handle to a CacheMgr.
using CacheMgrPtr = std::shared_ptr<CacheMgr>;

} // namespace engine
} // namespace milvus
} // namespace zilliz
......@@ -12,67 +12,31 @@ namespace zilliz {
namespace milvus {
namespace engine {
// Constructs the manager with running_ initialized to false.
ResourceMgr::ResourceMgr() : running_(false) {}
// Counts the entries of resources_ whose HasExecutor() returns true.
uint64_t
ResourceMgr::GetNumOfComputeResource() {
    uint64_t executor_count = 0;
    for (auto &entry : resources_) {
        if (!entry->HasExecutor()) {
            continue;
        }
        executor_count += 1;
    }
    return executor_count;
}
std::vector<ResourcePtr>
ResourceMgr::GetComputeResource() {
std::vector<ResourcePtr > result;
void
ResourceMgr::Start() {
std::lock_guard<std::mutex> lck(resources_mutex_);
for (auto &resource : resources_) {
if (resource->HasExecutor()) {
result.emplace_back(resource);
}
}
return result;
}
uint64_t
ResourceMgr::GetNumGpuResource() const {
uint64_t num = 0;
for (auto &res : resources_) {
if (res->Type() == ResourceType::GPU) {
num++;
}
resource->Start();
}
return num;
running_ = true;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
}
ResourcePtr
ResourceMgr::GetResource(ResourceType type, uint64_t device_id) {
for (auto &resource : resources_) {
if (resource->Type() == type && resource->DeviceId() == device_id) {
return resource;
}
void
ResourceMgr::Stop() {
{
std::lock_guard<std::mutex> lock(event_mutex_);
running_ = false;
queue_.push(nullptr);
event_cv_.notify_one();
}
return nullptr;
}
worker_thread_.join();
ResourcePtr
ResourceMgr::GetResourceByName(std::string name) {
std::lock_guard<std::mutex> lck(resources_mutex_);
for (auto &resource : resources_) {
if (resource->Name() == name) {
return resource;
}
resource->Stop();
}
return nullptr;
}
std::vector<ResourcePtr>
ResourceMgr::GetAllResouces() {
return resources_;
}
ResourceWPtr
......@@ -85,75 +49,85 @@ ResourceMgr::Add(ResourcePtr &&resource) {
return ret;
}
if (resource->Type() == ResourceType::DISK) {
resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1));
if (resource->type() == ResourceType::DISK) {
disk_resources_.emplace_back(ResourceWPtr(resource));
}
resources_.emplace_back(resource);
size_t index = resources_.size() - 1;
resource->RegisterSubscriber(std::bind(&ResourceMgr::PostEvent, this, std::placeholders::_1));
return ret;
}
void
ResourceMgr::Connect(const std::string &name1, const std::string &name2, Connection &connection) {
auto res1 = get_resource_by_name(name1);
auto res2 = get_resource_by_name(name2);
auto res1 = GetResource(name1);
auto res2 = GetResource(name2);
if (res1 && res2) {
res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection);
// TODO: enable when task balance supported
// res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection);
}
}
void
ResourceMgr::Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection) {
if (auto observe_a = res1.lock()) {
if (auto observe_b = res2.lock()) {
observe_a->AddNeighbour(std::static_pointer_cast<Node>(observe_b), connection);
observe_b->AddNeighbour(std::static_pointer_cast<Node>(observe_a), connection);
}
}
ResourceMgr::Clear() {
std::lock_guard<std::mutex> lck(resources_mutex_);
disk_resources_.clear();
resources_.clear();
}
void
ResourceMgr::Start() {
std::lock_guard<std::mutex> lck(resources_mutex_);
std::vector<ResourcePtr>
ResourceMgr::GetComputeResource() {
std::vector<ResourcePtr> result;
for (auto &resource : resources_) {
resource->Start();
if (resource->HasExecutor()) {
result.emplace_back(resource);
}
}
running_ = true;
worker_thread_ = std::thread(&ResourceMgr::event_process, this);
return result;
}
void
ResourceMgr::Stop() {
{
std::lock_guard<std::mutex> lock(event_mutex_);
running_ = false;
queue_.push(nullptr);
event_cv_.notify_one();
ResourcePtr
ResourceMgr::GetResource(ResourceType type, uint64_t device_id) {
for (auto &resource : resources_) {
if (resource->type() == type && resource->device_id() == device_id) {
return resource;
}
}
worker_thread_.join();
return nullptr;
}
std::lock_guard<std::mutex> lck(resources_mutex_);
ResourcePtr
ResourceMgr::GetResource(const std::string &name) {
for (auto &resource : resources_) {
resource->Stop();
if (resource->name() == name) {
return resource;
}
}
return nullptr;
}
void
ResourceMgr::Clear() {
std::lock_guard<std::mutex> lck(resources_mutex_);
disk_resources_.clear();
resources_.clear();
uint64_t
ResourceMgr::GetNumOfComputeResource() {
uint64_t count = 0;
for (auto &res : resources_) {
if (res->HasExecutor()) {
++count;
}
}
return count;
}
void
ResourceMgr::PostEvent(const EventPtr &event) {
std::lock_guard<std::mutex> lock(event_mutex_);
queue_.emplace(event);
event_cv_.notify_one();
uint64_t
ResourceMgr::GetNumGpuResource() const {
uint64_t num = 0;
for (auto &res : resources_) {
if (res->type() == ResourceType::GPU) {
num++;
}
}
return num;
}
std::string
......@@ -180,14 +154,13 @@ ResourceMgr::DumpTaskTables() {
return ss.str();
}
ResourcePtr
ResourceMgr::get_resource_by_name(const std::string &name) {
for (auto &res : resources_) {
if (res->Name() == name) {
return res;
}
void
ResourceMgr::post_event(const EventPtr &event) {
{
std::lock_guard<std::mutex> lock(event_mutex_);
queue_.emplace(event);
}
return nullptr;
event_cv_.notify_one();
}
void
......@@ -203,8 +176,6 @@ ResourceMgr::event_process() {
break;
}
// ENGINE_LOG_DEBUG << "ResourceMgr process " << *event;
if (subscriber_) {
subscriber_(event);
}
......
......@@ -22,78 +22,63 @@ namespace engine {
class ResourceMgr {
public:
ResourceMgr();
ResourceMgr() = default;
public:
/******** Management Interface ********/
void
Start();
void
Stop();
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
void
Clear();
inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber);
}
std::vector<ResourceWPtr> &
public:
/******** Management Interface ********/
inline std::vector<ResourceWPtr> &
GetDiskResources() {
return disk_resources_;
}
uint64_t
GetNumGpuResource() const;
// TODO: why return shared pointer
inline std::vector<ResourcePtr>
GetAllResources() {
return resources_;
}
std::vector<ResourcePtr>
GetComputeResource();
ResourcePtr
GetResource(ResourceType type, uint64_t device_id);
ResourcePtr
GetResourceByName(std::string name);
GetResource(const std::string &name);
std::vector<ResourcePtr>
GetAllResouces();
/*
* Return the number of resources that have an executor enabled;
*/
uint64_t
GetNumOfComputeResource();
std::vector<ResourcePtr>
GetComputeResource();
/*
* Add resource into Resource Management;
* Generate functions on events;
* Functions only modify bool variable, like event trigger;
*/
ResourceWPtr
Add(ResourcePtr &&resource);
void
Connect(const std::string &res1, const std::string &res2, Connection &connection);
/*
* Create connection between A and B;
*/
void
Connect(ResourceWPtr &res1, ResourceWPtr &res2, Connection &connection);
/*
* Synchronously start all resources;
* Finally, start the event-processing thread;
*/
void
Start();
void
Stop();
void
Clear();
void
PostEvent(const EventPtr &event);
uint64_t
GetNumGpuResource() const;
public:
// TODO: add stats interface(low)
public:
/******** Utlitity Functions ********/
/******** Utility Functions ********/
std::string
Dump();
......@@ -101,26 +86,26 @@ public:
DumpTaskTables();
private:
ResourcePtr
get_resource_by_name(const std::string &name);
void
post_event(const EventPtr &event);
void
event_process();
private:
std::queue<EventPtr> queue_;
std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_;
bool running_ = false;
std::vector<ResourceWPtr> disk_resources_;
std::vector<ResourcePtr> resources_;
mutable std::mutex resources_mutex_;
std::thread worker_thread_;
std::queue<EventPtr> queue_;
std::function<void(EventPtr)> subscriber_ = nullptr;
std::mutex event_mutex_;
std::condition_variable event_cv_;
std::thread worker_thread_;
};
using ResourceMgrPtr = std::shared_ptr<ResourceMgr>;
......
......@@ -21,43 +21,67 @@ std::mutex SchedInst::mutex_;
void
StartSchedulerService() {
server::ConfigNode &config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_RESOURCE);
auto resources = config.GetChild(server::CONFIG_RESOURCES).GetChildren();
for (auto &resource : resources) {
auto &resname = resource.first;
auto &resconf = resource.second;
auto type = resconf.GetValue(server::CONFIG_RESOURCE_TYPE);
try {
server::ConfigNode &config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_RESOURCE);
if (config.GetChildren().empty()) throw "resource_config null exception";
auto resources = config.GetChild(server::CONFIG_RESOURCES).GetChildren();
if (resources.empty()) throw "Children of resource_config null exception";
for (auto &resource : resources) {
auto &resname = resource.first;
auto &resconf = resource.second;
auto type = resconf.GetValue(server::CONFIG_RESOURCE_TYPE);
// auto memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_MEMORY);
auto device_id = resconf.GetInt64Value(server::CONFIG_RESOURCE_DEVICE_ID);
auto enable_loader = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_LOADER);
auto enable_executor = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_EXECUTOR);
auto device_id = resconf.GetInt64Value(server::CONFIG_RESOURCE_DEVICE_ID);
auto enable_loader = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_LOADER);
auto enable_executor = resconf.GetBoolValue(server::CONFIG_RESOURCE_ENABLE_EXECUTOR);
auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY);
auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY);
auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM);
ResMgrInst::GetInstance()->Add(ResourceFactory::Create(resname,
type,
device_id,
enable_loader,
enable_executor));
auto res = ResMgrInst::GetInstance()->Add(ResourceFactory::Create(resname,
type,
device_id,
enable_loader,
enable_executor));
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id);
}
if (res.lock()->type() == ResourceType::GPU) {
auto pinned_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_PIN_MEMORY, 300);
auto temp_memory = resconf.GetInt64Value(server::CONFIG_RESOURCE_TEMP_MEMORY, 300);
auto resource_num = resconf.GetInt64Value(server::CONFIG_RESOURCE_NUM, 2);
pinned_memory = 1024 * 1024 * pinned_memory;
temp_memory = 1024 * 1024 * temp_memory;
knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id,
pinned_memory,
temp_memory,
resource_num);
}
}
knowhere::FaissGpuResourceMgr::GetInstance().InitResource();
knowhere::FaissGpuResourceMgr::GetInstance().InitResource();
// auto default_connection = Connection("default_connection", 500.0);
auto connections = config.GetChild(server::CONFIG_RESOURCE_CONNECTIONS).GetChildren();
for (auto &conn : connections) {
auto &connect_name = conn.first;
auto &connect_conf = conn.second;
auto connect_speed = connect_conf.GetInt64Value(server::CONFIG_SPEED_CONNECTIONS);
auto connect_endpoint = connect_conf.GetValue(server::CONFIG_ENDPOINT_CONNECTIONS);
auto connections = config.GetChild(server::CONFIG_RESOURCE_CONNECTIONS).GetChildren();
if(connections.empty()) throw "connections config null exception";
for (auto &conn : connections) {
auto &connect_name = conn.first;
auto &connect_conf = conn.second;
auto connect_speed = connect_conf.GetInt64Value(server::CONFIG_SPEED_CONNECTIONS);
auto connect_endpoint = connect_conf.GetValue(server::CONFIG_ENDPOINT_CONNECTIONS);
std::string delimiter = "===";
std::string left = connect_endpoint.substr(0, connect_endpoint.find(delimiter));
std::string right = connect_endpoint.substr(connect_endpoint.find(delimiter) + 3,
connect_endpoint.length());
std::string delimiter = "===";
std::string left = connect_endpoint.substr(0, connect_endpoint.find(delimiter));
std::string right = connect_endpoint.substr(connect_endpoint.find(delimiter) + 3,
connect_endpoint.length());
auto connection = Connection(connect_name, connect_speed);
ResMgrInst::GetInstance()->Connect(left, right, connection);
auto connection = Connection(connect_name, connect_speed);
ResMgrInst::GetInstance()->Connect(left, right, connection);
}
} catch (const char* msg) {
SERVER_LOG_ERROR << msg;
exit(-1);
}
ResMgrInst::GetInstance()->Start();
......
......@@ -143,7 +143,7 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
auto task = load_completed_event->task_table_item_->task;
// if this resource is disk, assign it to smallest cost resource
if (self->Type() == ResourceType::DISK) {
if (self->type() == ResourceType::DISK) {
// step 1: calculate shortest path per resource, from disk to compute resource
auto compute_resources = res_mgr_.lock()->GetComputeResource();
std::vector<std::vector<std::string>> paths;
......@@ -176,11 +176,11 @@ Scheduler::OnLoadCompleted(const EventPtr &event) {
task->path() = task_path;
}
if(self->Name() == task->path().Last()) {
if(self->name() == task->path().Last()) {
self->WakeupLoader();
} else {
auto next_res_name = task->path().Next();
auto next_res = res_mgr_.lock()->GetResourceByName(next_res_name);
auto next_res = res_mgr_.lock()->GetResource(next_res_name);
load_completed_event->task_table_item_->Move();
next_res->task_table().Put(task);
}
......
......@@ -6,6 +6,8 @@
#include "TaskTable.h"
#include "event/TaskTableUpdatedEvent.h"
#include "Utils.h"
#include <vector>
#include <sstream>
#include <ctime>
......@@ -15,14 +17,6 @@ namespace zilliz {
namespace milvus {
namespace engine {
// Returns the system clock's time since its epoch, in whole milliseconds.
uint64_t
get_now_timestamp() {
    const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::milliseconds>(since_epoch).count();
}
std::string
ToString(TaskTableItemState state) {
switch (state) {
......@@ -64,7 +58,7 @@ TaskTableItem::Load() {
if (state == TaskTableItemState::START) {
state = TaskTableItemState::LOADING;
lock.unlock();
timestamp.load = get_now_timestamp();
timestamp.load = get_current_timestamp();
return true;
}
return false;
......@@ -75,7 +69,7 @@ TaskTableItem::Loaded() {
if (state == TaskTableItemState::LOADING) {
state = TaskTableItemState::LOADED;
lock.unlock();
timestamp.loaded = get_now_timestamp();
timestamp.loaded = get_current_timestamp();
return true;
}
return false;
......@@ -86,7 +80,7 @@ TaskTableItem::Execute() {
if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::EXECUTING;
lock.unlock();
timestamp.execute = get_now_timestamp();
timestamp.execute = get_current_timestamp();
return true;
}
return false;
......@@ -97,8 +91,8 @@ TaskTableItem::Executed() {
if (state == TaskTableItemState::EXECUTING) {
state = TaskTableItemState::EXECUTED;
lock.unlock();
timestamp.executed = get_now_timestamp();
timestamp.finish = get_now_timestamp();
timestamp.executed = get_current_timestamp();
timestamp.finish = get_current_timestamp();
return true;
}
return false;
......@@ -109,7 +103,7 @@ TaskTableItem::Move() {
if (state == TaskTableItemState::LOADED) {
state = TaskTableItemState::MOVING;
lock.unlock();
timestamp.move = get_now_timestamp();
timestamp.move = get_current_timestamp();
return true;
}
return false;
......@@ -120,8 +114,8 @@ TaskTableItem::Moved() {
if (state == TaskTableItemState::MOVING) {
state = TaskTableItemState::MOVED;
lock.unlock();
timestamp.moved = get_now_timestamp();
timestamp.finish = get_now_timestamp();
timestamp.moved = get_current_timestamp();
timestamp.finish = get_current_timestamp();
return true;
}
return false;
......@@ -177,7 +171,7 @@ TaskTable::Put(TaskPtr task) {
item->id = id_++;
item->task = std::move(task);
item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp();
item->timestamp.start = get_current_timestamp();
table_.push_back(item);
if (subscriber_) {
subscriber_();
......@@ -192,7 +186,7 @@ TaskTable::Put(std::vector<TaskPtr> &tasks) {
item->id = id_++;
item->task = std::move(task);
item->state = TaskTableItemState::START;
item->timestamp.start = get_now_timestamp();
item->timestamp.start = get_current_timestamp();
table_.push_back(item);
}
if (subscriber_) {
......
......@@ -40,20 +40,17 @@ struct TaskTimestamp {
};
struct TaskTableItem {
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex(), priority(0) {}
TaskTableItem() : id(0), state(TaskTableItemState::INVALID), mutex() {}
TaskTableItem(const TaskTableItem &src)
: id(src.id), state(src.state), mutex(), priority(src.priority) {}
: id(src.id), state(src.state), mutex() {}
uint64_t id; // auto increment from 0;
// TODO: add tag into task
TaskPtr task; // the task;
TaskTableItemState state; // the state;
std::mutex mutex;
TaskTimestamp timestamp;
uint8_t priority; // just a number, meaningless;
bool
IsFinish();
......@@ -113,7 +110,7 @@ public:
Get(uint64_t index);
/*
* TODO
* TODO(wxyu): BIG GC
* Remove sequence task which is DONE or MOVED from front;
* Called by ?
*/
......@@ -135,6 +132,7 @@ public:
Size() {
return table_.size();
}
public:
TaskTableItemPtr &
operator[](uint64_t index) {
......@@ -225,7 +223,6 @@ public:
Dump();
private:
// TODO: map better ?
std::uint64_t id_ = 0;
mutable std::mutex id_mutex_;
std::deque<TaskTableItemPtr> table_;
......
......@@ -4,16 +4,17 @@
* Proprietary and confidential.
******************************************************************************/
#include <chrono>
#include "Utils.h"
#include <chrono>
namespace zilliz {
namespace milvus {
namespace engine {
uint64_t
get_current_timestamp()
{
get_current_timestamp() {
std::chrono::time_point<std::chrono::system_clock> now = std::chrono::system_clock::now();
auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
......
......@@ -3,6 +3,7 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include <cstdint>
......
......@@ -15,6 +15,7 @@ namespace engine {
class Connection {
public:
// TODO: update construct function, speed: double->uint64_t
Connection(std::string name, double speed)
: name_(std::move(name)), speed_(speed) {}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <memory>
namespace zilliz {
namespace milvus {
namespace engine {

// Abstract handler interface; concrete handlers implement Exec().
class RegisterHandler {
public:
    // Virtual so that destroying a derived handler through a RegisterHandler
    // pointer runs the derived destructor.
    virtual ~RegisterHandler() = default;

    // Performs the handler's action.
    virtual void Exec() = 0;
};

// Shared-ownership handle to a RegisterHandler implementation.
using RegisterHandlerPtr = std::shared_ptr<RegisterHandler>;

}
}
}
\ No newline at end of file
......@@ -12,7 +12,8 @@ namespace zilliz {
namespace milvus {
namespace engine {
std::ostream &operator<<(std::ostream &out, const Resource &resource) {
std::ostream &
operator<<(std::ostream &out, const Resource &resource) {
out << resource.Dump();
return out;
}
......@@ -25,11 +26,9 @@ Resource::Resource(std::string name,
: name_(std::move(name)),
type_(type),
device_id_(device_id),
running_(false),
enable_loader_(enable_loader),
enable_executor_(enable_executor),
load_flag_(false),
exec_flag_(false) {
enable_executor_(enable_executor) {
// register subscriber in tasktable
task_table_.RegisterSubscriber([&] {
if (subscriber_) {
auto event = std::make_shared<TaskTableUpdatedEvent>(shared_from_this());
......@@ -38,7 +37,8 @@ Resource::Resource(std::string name,
});
}
void Resource::Start() {
void
Resource::Start() {
running_ = true;
if (enable_loader_) {
loader_thread_ = std::thread(&Resource::loader_function, this);
......@@ -48,7 +48,8 @@ void Resource::Start() {
}
}
void Resource::Stop() {
void
Resource::Stop() {
running_ = false;
if (enable_loader_) {
WakeupLoader();
......@@ -60,11 +61,8 @@ void Resource::Stop() {
}
}
// Accessor returning a mutable reference to this resource's task_table_.
TaskTable &
Resource::task_table() {
    return task_table_;
}
void Resource::WakeupLoader() {
void
Resource::WakeupLoader() {
{
std::lock_guard<std::mutex> lock(load_mutex_);
load_flag_ = true;
......@@ -72,7 +70,8 @@ void Resource::WakeupLoader() {
load_cv_.notify_one();
}
void Resource::WakeupExecutor() {
void
Resource::WakeupExecutor() {
{
std::lock_guard<std::mutex> lock(exec_mutex_);
exec_flag_ = true;
......@@ -80,6 +79,15 @@ void Resource::WakeupExecutor() {
exec_cv_.notify_one();
}
// Returns how many items of task_table_ are currently in the LOADED state.
uint64_t
Resource::NumOfTaskToExec() {
    uint64_t loaded = 0;
    for (auto &item : task_table_) {
        loaded += (item->state == TaskTableItemState::LOADED) ? 1 : 0;
    }
    return loaded;
}
TaskTableItemPtr Resource::pick_task_load() {
auto indexes = task_table_.PickToLoad(10);
for (auto index : indexes) {
......@@ -156,11 +164,6 @@ void Resource::executor_function() {
}
}
// Builds a fresh handler for `type` using the factory stored in register_table_.
// The factory is looked up with find() instead of operator[] so that asking for
// an unregistered type does not insert an empty entry into register_table_.
// Throws std::bad_function_call when no callable factory is registered, which
// is the same exception that invoking a default-inserted empty std::function
// raised before.
RegisterHandlerPtr Resource::GetRegisterFunc(const RegisterType &type) {
    auto found = register_table_.find(type);
    if (found == register_table_.end() || !found->second) {
        throw std::bad_function_call();
    }
    // construct object each time.
    return found->second();
}
}
}
}
\ No newline at end of file
......@@ -21,7 +21,6 @@
#include "../task/Task.h"
#include "Connection.h"
#include "Node.h"
#include "RegisterHandler.h"
namespace zilliz {
......@@ -35,13 +34,6 @@ enum class ResourceType {
GPU = 2
};
// Key type used to register and look up handler factories on a Resource.
// NOTE(review): the enumerator names suggest the event each handler reacts to;
// confirm against the code that registers the handlers.
enum class RegisterType {
START_UP,
ON_FINISH_TASK,
ON_COPY_COMPLETED,
ON_TASK_TABLE_UPDATED,
};
class Resource : public Node, public std::enable_shared_from_this<Resource> {
public:
/*
......@@ -68,56 +60,51 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
void
WakeupExecutor();
public:
template<typename T>
void Register_T(const RegisterType &type) {
register_table_.emplace(type, [] { return std::make_shared<T>(); });
}
RegisterHandlerPtr
GetRegisterFunc(const RegisterType &type);
inline void
RegisterSubscriber(std::function<void(EventPtr)> subscriber) {
subscriber_ = std::move(subscriber);
}
inline virtual std::string
Dump() const {
return "<Resource>";
}
public:
inline std::string
Name() const {
name() const {
return name_;
}
inline ResourceType
Type() const {
type() const {
return type_;
}
inline uint64_t
DeviceId() {
device_id() const {
return device_id_;
}
// TODO: better name?
TaskTable &
task_table() {
return task_table_;
}
public:
inline bool
HasLoader() {
HasLoader() const {
return enable_loader_;
}
// TODO: better name?
inline bool
HasExecutor() {
HasExecutor() const {
return enable_executor_;
}
// TODO: const
uint64_t
NumOfTaskToExec() {
uint64_t count = 0;
for (auto &task : task_table_) {
if (task->state == TaskTableItemState::LOADED) ++count;
}
return count;
}
NumOfTaskToExec();
// TODO: need double ?
inline uint64_t
......@@ -130,14 +117,6 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
return total_task_;
}
TaskTable &
task_table();
inline virtual std::string
Dump() const {
return "<Resource>";
}
friend std::ostream &operator<<(std::ostream &out, const Resource &resource);
protected:
......@@ -198,6 +177,7 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
protected:
uint64_t device_id_;
std::string name_;
private:
ResourceType type_;
......@@ -206,17 +186,16 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
uint64_t total_cost_ = 0;
uint64_t total_task_ = 0;
std::map<RegisterType, std::function<RegisterHandlerPtr()>> register_table_;
std::function<void(EventPtr)> subscriber_ = nullptr;
bool running_;
bool running_ = false;
bool enable_loader_ = true;
bool enable_executor_ = true;
std::thread loader_thread_;
std::thread executor_thread_;
bool load_flag_;
bool exec_flag_;
bool load_flag_ = false;
bool exec_flag_ = false;
std::mutex load_mutex_;
std::mutex exec_mutex_;
std::condition_variable load_cv_;
......
......@@ -24,12 +24,6 @@ XDeleteTask::Execute() {
delete_context_ptr_->ResourceDone();
}
// Produces a new XDeleteTask constructed from this task's delete_context_ptr_.
TaskPtr
XDeleteTask::Clone() {
    return std::make_shared<XDeleteTask>(delete_context_ptr_);
}
}
}
}
......@@ -24,9 +24,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
public:
DeleteContextPtr delete_context_ptr_;
};
......
......@@ -163,6 +163,7 @@ XSearchTask::Execute() {
double span = rc.RecordSection("do search for context:" + context->Identity());
context->AccumSearchCost(span);
//step 3: cluster result
SearchContext::ResultSet result_set;
auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
......@@ -176,7 +177,6 @@ XSearchTask::Execute() {
span = rc.RecordSection("reduce topk for context:" + context->Identity());
context->AccumReduceCost(span);
} catch (std::exception &ex) {
ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
......@@ -193,16 +193,6 @@ XSearchTask::Execute() {
index_engine_ = nullptr;
}
// Creates a new XSearchTask for the same file_ and copies index_id_,
// search_contexts_ and metric_l2 into it. The index engine is cloned only when
// one is attached.
TaskPtr
XSearchTask::Clone() {
    auto ret = std::make_shared<XSearchTask>(file_);
    ret->index_id_ = index_id_;
    // Execute() resets index_engine_ to nullptr when it finishes, so calling
    // Clone() on an already-executed task must not dereference it. In that
    // case the clone keeps whatever engine value its constructor set.
    if (index_engine_ != nullptr) {
        ret->index_engine_ = index_engine_->Clone();
    }
    ret->search_contexts_ = search_contexts_;
    ret->metric_l2 = metric_l2;
    return ret;
}
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
uint64_t nq,
......
......@@ -23,9 +23,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
public:
static Status ClusterResult(const std::vector<long> &output_ids,
const std::vector<float> &output_distence,
......
......@@ -68,14 +68,9 @@ public:
virtual void
Execute() = 0;
// TODO: dont use this method to support task move
virtual TaskPtr
Clone() = 0;
public:
Path task_path_;
std::vector<SearchContextPtr> search_contexts_;
ScheduleTaskPtr task_;
TaskType type_;
TaskLabelPtr label_ = nullptr;
};
......
......@@ -21,7 +21,6 @@ TaskConvert(const ScheduleTaskPtr &schedule_task) {
auto task = std::make_shared<XSearchTask>(load_task->file_);
task->label() = std::make_shared<DefaultLabel>();
task->search_contexts_ = load_task->search_contexts_;
task->task_ = schedule_task;
return task;
}
case ScheduleTaskType::kDelete: {
......
......@@ -27,15 +27,6 @@ TestTask::Execute() {
done_ = true;
}
// Creates a new TestTask built from a null file schema and copies this task's
// load_count_ and exec_count_ into it.
TaskPtr
TestTask::Clone() {
    TableFileSchemaPtr no_file = nullptr;
    auto copy = std::make_shared<TestTask>(no_file);
    copy->exec_count_ = exec_count_;
    copy->load_count_ = load_count_;
    return copy;
}
void
TestTask::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -23,9 +23,6 @@ public:
void
Execute() override;
TaskPtr
Clone() override;
void
Wait();
......
......@@ -25,6 +25,7 @@ public:
}
protected:
explicit
TaskLabel(TaskLabelType type) : type_(type) {}
private:
......
......@@ -96,6 +96,7 @@ TableSchema BuildTableSchema() {
tb_schema.table_name = TABLE_NAME;
tb_schema.dimension = TABLE_DIMENSION;
tb_schema.index_file_size = TABLE_INDEX_FILE_SIZE;
tb_schema.metric_type = MetricType::L2;
return tb_schema;
}
......@@ -291,7 +292,6 @@ ClientTest::Test(const std::string& address, const std::string& port) {
index.table_name = TABLE_NAME;
index.index_type = IndexType::gpu_ivfflat;
index.nlist = 16384;
index.metric_type = 1;
Status stat = conn->CreateIndex(index);
std::cout << "CreateIndex function call status: " << stat.ToString() << std::endl;
......
......@@ -84,6 +84,7 @@ ClientProxy::CreateTable(const TableSchema &param) {
schema.mutable_table_name()->set_table_name(param.table_name);
schema.set_dimension(param.dimension);
schema.set_index_file_size(param.index_file_size);
schema.set_metric_type((int32_t)param.metric_type);
return client_ptr_->CreateTable(schema);
} catch (std::exception &ex) {
......@@ -116,11 +117,9 @@ ClientProxy::CreateIndex(const IndexParam &index_param) {
try {
//TODO:add index params
::milvus::grpc::IndexParam grpc_index_param;
grpc_index_param.mutable_table_name()->set_table_name(
index_param.table_name);
grpc_index_param.mutable_table_name()->set_table_name(index_param.table_name);
grpc_index_param.mutable_index()->set_index_type((int32_t)index_param.index_type);
grpc_index_param.mutable_index()->set_nlist(index_param.nlist);
grpc_index_param.mutable_index()->set_metric_type(index_param.metric_type);
return client_ptr_->CreateIndex(grpc_index_param);
} catch (std::exception &ex) {
......@@ -240,16 +239,16 @@ ClientProxy::Search(const std::string &table_name,
}
//step 3: search vectors
std::vector<::milvus::grpc::TopKQueryResult> result_array;
Status status = client_ptr_->Search(result_array, search_param);
::milvus::grpc::TopKQueryResultList topk_query_result_list;
Status status = client_ptr_->Search(topk_query_result_list, search_param);
//step 4: convert result array
for (auto &grpc_topk_result : result_array) {
for (uint64_t i = 0; i < topk_query_result_list.topk_query_result_size(); ++i) {
TopKQueryResult result;
for (size_t i = 0; i < grpc_topk_result.query_result_arrays_size(); i++) {
for (uint64_t j = 0; j < topk_query_result_list.topk_query_result(i).query_result_arrays_size(); ++j) {
QueryResult query_result;
query_result.id = grpc_topk_result.query_result_arrays(i).id();
query_result.distance = grpc_topk_result.query_result_arrays(i).distance();
query_result.id = topk_query_result_list.topk_query_result(i).query_result_arrays(j).id();
query_result.distance = topk_query_result_list.topk_query_result(i).query_result_arrays(j).distance();
result.query_result_arrays.emplace_back(query_result);
}
......@@ -273,6 +272,7 @@ ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_sch
table_schema.table_name = grpc_schema.table_name().table_name();
table_schema.dimension = grpc_schema.dimension();
table_schema.index_file_size = grpc_schema.index_file_size();
table_schema.metric_type = (MetricType)grpc_schema.metric_type();
return status;
} catch (std::exception &ex) {
......@@ -378,7 +378,6 @@ ClientProxy::DescribeIndex(const std::string &table_name, IndexParam &index_para
Status status = client_ptr_->DescribeIndex(grpc_table_name, grpc_index_param);
index_param.index_type = (IndexType)(grpc_index_param.mutable_index()->index_type());
index_param.nlist = grpc_index_param.mutable_index()->nlist();
index_param.metric_type = grpc_index_param.mutable_index()->metric_type();
return status;
......
......@@ -121,28 +121,21 @@ GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids,
}
Status
GrpcClient::Search(std::vector<::milvus::grpc::TopKQueryResult>& result_array,
const ::milvus::grpc::SearchParam& search_param) {
GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
const ::milvus::grpc::SearchParam &search_param) {
::milvus::grpc::TopKQueryResult query_result;
ClientContext context;
std::unique_ptr<ClientReader<::milvus::grpc::TopKQueryResult> > reader(
stub_->Search(&context, search_param));
while (reader->Read(&query_result)) {
result_array.emplace_back(query_result);
}
::grpc::Status grpc_status = reader->Finish();
::grpc::Status grpc_status = stub_->Search(&context, search_param, &topk_query_result_list);
if (!grpc_status.ok()) {
std::cerr << "SearchVector rpc failed!" << std::endl;
std::cerr << grpc_status.error_message() << std::endl;
return Status(StatusCode::RPCFailed, grpc_status.error_message());
}
if (query_result.status().error_code() != grpc::SUCCESS) {
std::cerr << query_result.status().reason() << std::endl;
if (topk_query_result_list.status().error_code() != grpc::SUCCESS) {
std::cerr << topk_query_result_list.status().reason() << std::endl;
return Status(StatusCode::ServerFailed,
query_result.status().reason());
topk_query_result_list.status().reason());
}
return Status::OK();
......
......@@ -50,8 +50,8 @@ public:
Status& status);
Status
Search(std::vector<grpc::TopKQueryResult>& result_array,
const grpc::SearchParam& search_param);
Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
const grpc::SearchParam &search_param);
Status
DescribeTable(grpc::TableSchema& grpc_schema,
......
......@@ -22,6 +22,11 @@ enum class IndexType {
mix_nsg,
};
/**
 * @brief Metric type of a table (stored in TableSchema::metric_type)
 *
 * The numeric values are significant: the client casts the enumerator to
 * int32_t when filling the gRPC table schema and casts the received integer
 * back to MetricType when describing a table.
 */
enum class MetricType {
    L2 = 1,  ///< NOTE(review): presumably L2 (Euclidean) distance - confirm
    IP = 2   ///< NOTE(review): presumably inner product - confirm
};
/**
* @brief Connect API parameter
*/
......@@ -37,6 +42,7 @@ struct TableSchema {
std::string table_name; ///< Table name
int64_t dimension = 0; ///< Vector dimension, must be a positive value
int64_t index_file_size = 0; ///< Index file size, must be a positive value
MetricType metric_type = MetricType::L2; ///< Index metric type
};
/**
......@@ -77,7 +83,6 @@ struct IndexParam {
std::string table_name;
IndexType index_type;
int32_t nlist;
int32_t metric_type;
};
/**
......
......@@ -6,6 +6,7 @@
#include "DBWrapper.h"
#include "ServerConfig.h"
#include "db/Factories.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/StringHelpFunctions.h"
......@@ -95,8 +96,7 @@ ServerError DBWrapper::StartService() {
//create db instance
std::string msg = opt.meta.path;
try {
engine::DB* db = nullptr;
zilliz::milvus::engine::DB::Open(opt, &db);
engine::DB* db = engine::DBFactory::Build(opt);
db_.reset(db);
} catch(std::exception& ex) {
msg = ex.what();
......
......@@ -56,6 +56,9 @@ static const char* CONFIG_RESOURCE_MEMORY = "memory";
static const char* CONFIG_RESOURCE_DEVICE_ID = "device_id";
static const char* CONFIG_RESOURCE_ENABLE_LOADER = "enable_loader";
static const char* CONFIG_RESOURCE_ENABLE_EXECUTOR = "enable_executor";
static const char* CONFIG_RESOURCE_NUM = "gpu_resource_num";
static const char* CONFIG_RESOURCE_PIN_MEMORY = "pinned_memory";
static const char* CONFIG_RESOURCE_TEMP_MEMORY = "temp_memory";
static const char* CONFIG_RESOURCE_CONNECTIONS = "connections";
static const char* CONFIG_SPEED_CONNECTIONS = "speed";
static const char* CONFIG_ENDPOINT_CONNECTIONS = "endpoint";
......
......@@ -25,6 +25,7 @@
#include <grpcpp/security/credentials.h>
#include <grpcpp/grpcpp.h>
namespace zilliz {
namespace milvus {
namespace server {
......@@ -36,11 +37,11 @@ constexpr long MESSAGE_SIZE = -1;
class NoReusePortOption : public ::grpc::ServerBuilderOption {
public:
void UpdateArguments(::grpc::ChannelArguments* args) override {
void UpdateArguments(::grpc::ChannelArguments *args) override {
args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
}
void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> *
plugins) override {}
};
......@@ -78,6 +79,7 @@ GrpcMilvusServer::StartService() {
server = builder.BuildAndStart();
server->Wait();
}
void
......
......@@ -49,7 +49,7 @@ endforeach()
add_subdirectory(server)
add_subdirectory(db)
add_subdirectory(knowhere)
#add_subdirectory(knowhere)
add_subdirectory(metrics)
#add_subdirectory(scheduler)
#add_subdirectory(storage)
\ No newline at end of file
......@@ -90,5 +90,5 @@ set(knowhere_libs
target_link_libraries(db_test ${db_libs} ${knowhere_libs} ${unittest_libs})
install(TARGETS db_test DESTINATION bin)
install(TARGETS db_test DESTINATION unittest)
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册