提交 6dd89bb0 编写于 作者: X xj.lin

MS-534 1. fix


Former-commit-id: 6d1927610d68cdb2689070699e42851223ee67d9
上级 724be3da
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <faiss/gpu/StandardGpuResources.h> #include <faiss/gpu/StandardGpuResources.h>
#include "ivf.h" #include "ivf.h"
#include "src/utils/BlockingQueue.h"
namespace zilliz { namespace zilliz {
...@@ -16,12 +17,15 @@ struct Resource { ...@@ -16,12 +17,15 @@ struct Resource {
std::shared_ptr<faiss::gpu::StandardGpuResources> faiss_res; std::shared_ptr<faiss::gpu::StandardGpuResources> faiss_res;
int64_t id; int64_t id;
std::mutex mutex;
}; };
using ResPtr = std::shared_ptr<Resource>; using ResPtr = std::shared_ptr<Resource>;
using ResWPtr = std::weak_ptr<Resource>; using ResWPtr = std::weak_ptr<Resource>;
class FaissGpuResourceMgr { class FaissGpuResourceMgr {
public: public:
using ResBQ = zilliz::milvus::server::BlockingQueue<ResPtr>;
struct DeviceParams { struct DeviceParams {
int64_t temp_mem_size = 0; int64_t temp_mem_size = 0;
int64_t pinned_mem_size = 0; int64_t pinned_mem_size = 0;
...@@ -55,11 +59,8 @@ class FaissGpuResourceMgr { ...@@ -55,11 +59,8 @@ class FaissGpuResourceMgr {
// allocate gpu memory before search // allocate gpu memory before search
// this func will return true if the device is idle and an idle resource exists. // this func will return true if the device is idle and an idle resource exists.
bool //bool
GetRes(const int64_t& device_id, ResPtr &res, const int64_t& alloc_size = 0); //GetRes(const int64_t& device_id, ResPtr &res, const int64_t& alloc_size = 0);
void
MoveToInuse(const int64_t &device_id, const ResPtr& res);
void void
MoveToIdle(const int64_t &device_id, const ResPtr& res); MoveToIdle(const int64_t &device_id, const ResPtr& res);
...@@ -67,33 +68,34 @@ class FaissGpuResourceMgr { ...@@ -67,33 +68,34 @@ class FaissGpuResourceMgr {
void void
Dump(); Dump();
protected:
void
RemoveResource(const int64_t& device_id, const ResPtr& res, std::map<int64_t, std::vector<ResPtr>>& resource_pool);
protected: protected:
bool is_init = false; bool is_init = false;
std::mutex mutex_;
std::map<int64_t, DeviceParams> devices_params_; std::map<int64_t, DeviceParams> devices_params_;
std::map<int64_t, std::vector<ResPtr>> in_use_; std::map<int64_t, ResBQ> idle_map;
std::map<int64_t, std::vector<ResPtr>> idle_;
}; };
class ResScope { class ResScope {
public: public:
ResScope(const int64_t device_id, ResPtr &res) : resource(res), device_id(device_id) { ResScope(const int64_t device_id, ResPtr &res) : resource(res), device_id(device_id), move(true) {
FaissGpuResourceMgr::GetInstance().MoveToInuse(device_id, resource); res->mutex.lock();
}
ResScope(ResPtr &res) : resource(res), device_id(-1), move(false) {
res->mutex.lock();
} }
~ResScope() { ~ResScope() {
//resource->faiss_res->noTempMemory(); if (move) {
FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource); FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource);
} }
resource->mutex.unlock();
}
private: private:
ResPtr resource; ResPtr resource;
int64_t device_id; int64_t device_id;
bool move = true;
}; };
class GPUIndex { class GPUIndex {
......
...@@ -130,19 +130,17 @@ void GPUIVF::search_impl(int64_t n, ...@@ -130,19 +130,17 @@ void GPUIVF::search_impl(int64_t n,
float *distances, float *distances,
int64_t *labels, int64_t *labels,
const Config &cfg) { const Config &cfg) {
// TODO(linxj): allocate mem std::lock_guard<std::mutex> lk(mutex_);
auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_);
if (temp_res) {
ResScope rs(gpu_id_, temp_res);
if (auto device_index = std::static_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) { if (auto device_index = std::static_pointer_cast<faiss::gpu::GpuIndexIVF>(index_)) {
auto nprobe = cfg.get_with_default("nprobe", size_t(1)); auto nprobe = cfg.get_with_default("nprobe", size_t(1));
std::lock_guard<std::mutex> lk(mutex_);
device_index->setNumProbes(nprobe); device_index->setNumProbes(nprobe);
{
// TODO(linxj): allocate mem
ResScope rs(res_);
device_index->search(n, (float *) data, k, distances, labels); device_index->search(n, (float *) data, k, distances, labels);
} }
} else {
KNOWHERE_THROW_MSG("search can't get gpu resource");
} }
} }
...@@ -283,119 +281,70 @@ void FaissGpuResourceMgr::InitResource() { ...@@ -283,119 +281,70 @@ void FaissGpuResourceMgr::InitResource() {
is_init = true; is_init = true;
for(auto& device : devices_params_) { for(auto& device : devices_params_) {
auto& resource_vec = idle_[device.first]; auto& device_id = device.first;
auto& device_param = device.second;
auto& bq = idle_map[device_id];
for (int64_t i = 0; i < device.second.resource_num; ++i) { for (int64_t i = 0; i < device_param.resource_num; ++i) {
auto res = std::make_shared<faiss::gpu::StandardGpuResources>(); auto raw_resource = std::make_shared<faiss::gpu::StandardGpuResources>();
// TODO(linxj): enable set pinned memory // TODO(linxj): enable set pinned memory
//res->noTempMemory(); auto res_wrapper = std::make_shared<Resource>(raw_resource);
auto res_wrapper = std::make_shared<Resource>(res); AllocateTempMem(res_wrapper, device_id, 0);
AllocateTempMem(res_wrapper, device.first, 0);
resource_vec.emplace_back(res_wrapper); bq.Put(res_wrapper);
} }
} }
} }
ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id, ResPtr FaissGpuResourceMgr::GetRes(const int64_t &device_id,
const int64_t &alloc_size) { const int64_t &alloc_size) {
std::lock_guard<std::mutex> lk(mutex_);
if (!is_init) {
InitResource(); InitResource();
is_init = true;
}
auto search = idle_.find(device_id);
if (search != idle_.end()) {
auto res = search->second.back();
//AllocateTempMem(res, device_id, alloc_size);
search->second.pop_back(); auto finder = idle_map.find(device_id);
return res; if (finder != idle_map.end()) {
auto& bq = finder->second;
auto&& resource = bq.Take();
AllocateTempMem(resource, device_id, alloc_size);
return resource;
} }
return nullptr; return nullptr;
} }
bool FaissGpuResourceMgr::GetRes(const int64_t &device_id, //bool FaissGpuResourceMgr::GetRes(const int64_t &device_id,
ResPtr &res, // ResPtr &res,
const int64_t &alloc_size) { // const int64_t &alloc_size) {
std::lock_guard<std::mutex> lk(mutex_); // InitResource();
//
if (!is_init) { // std::lock_guard<std::mutex> lk(res->mutex);
InitResource(); // AllocateTempMem(res, device_id, alloc_size);
is_init = true; // return true;
} //}
auto search = idle_.find(device_id);
if (search != idle_.end()) {
auto &res_vec = search->second;
for (auto it = res_vec.cbegin(); it != res_vec.cend(); ++it) {
if ((*it)->id == res->id) {
//AllocateTempMem(res, device_id, alloc_size);
res_vec.erase(it);
return true;
}
}
}
// else
return false;
}
void FaissGpuResourceMgr::MoveToInuse(const int64_t &device_id, const ResPtr &res) {
std::lock_guard<std::mutex> lk(mutex_);
RemoveResource(device_id, res, idle_);
in_use_[device_id].push_back(res);
}
void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res) { void FaissGpuResourceMgr::MoveToIdle(const int64_t &device_id, const ResPtr &res) {
std::lock_guard<std::mutex> lk(mutex_); auto finder = idle_map.find(device_id);
RemoveResource(device_id, res, in_use_); if (finder != idle_map.end()) {
auto it = idle_[device_id].begin(); auto& bq = finder->second;
idle_[device_id].insert(it, res); bq.Put(res);
}
void
FaissGpuResourceMgr::RemoveResource(const int64_t &device_id,
const ResPtr &res,
std::map<int64_t, std::vector<ResPtr>> &resource_pool) {
if (resource_pool.find(device_id) != resource_pool.end()) {
std::vector<ResPtr> &res_array = resource_pool[device_id];
res_array.erase(std::remove_if(res_array.begin(), res_array.end(),
[&](ResPtr &ptr) { return ptr->id == res->id; }),
res_array.end());
} }
} }
void FaissGpuResourceMgr::Free() { void FaissGpuResourceMgr::Free() {
for (auto &item : in_use_) { for (auto &item : idle_map) {
auto& res_vec = item.second; auto& bq = item.second;
res_vec.clear(); while (!bq.Empty()) {
bq.Take();
} }
for (auto &item : idle_) {
auto& res_vec = item.second;
res_vec.clear();
} }
is_init = false; is_init = false;
} }
void void
FaissGpuResourceMgr::Dump() { FaissGpuResourceMgr::Dump() {
std::cout << "In used resource" << std::endl; for (auto &item : idle_map) {
for(auto& item: in_use_) { auto& bq = item.second;
std::cout << "device_id: " << item.first << std::endl; std::cout << "device_id: " << item.first
for(auto& elem : item.second) { << ", resource count:" << bq.Size();
std::cout << "resource_id: " << elem->id << std::endl;
}
}
std::cout << "Idle resource" << std::endl;
for(auto& item: idle_) {
std::cout << "device_id: " << item.first << std::endl;
for(auto& elem : item.second) {
std::cout << "resource_id: " << elem->id << std::endl;
}
} }
} }
......
...@@ -386,7 +386,7 @@ class GPURESTEST ...@@ -386,7 +386,7 @@ class GPURESTEST
int64_t elems = 0; int64_t elems = 0;
}; };
const int search_count = 10; const int search_count = 18;
const int load_count = 3; const int load_count = 3;
TEST_F(GPURESTEST, gpu_ivf_resource_test) { TEST_F(GPURESTEST, gpu_ivf_resource_test) {
......
#pragma once #pragma once
#include "Log.h" //#include "Log.h"
#include "Error.h" #include "Error.h"
namespace zilliz { namespace zilliz {
...@@ -17,7 +17,7 @@ BlockingQueue<T>::Put(const T &task) { ...@@ -17,7 +17,7 @@ BlockingQueue<T>::Put(const T &task) {
std::string error_msg = std::string error_msg =
"blocking queue is full, capacity: " + std::to_string(capacity_) + " queue_size: " + "blocking queue is full, capacity: " + std::to_string(capacity_) + " queue_size: " +
std::to_string(queue_.size()); std::to_string(queue_.size());
SERVER_LOG_ERROR << error_msg; //SERVER_LOG_ERROR << error_msg;
throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg); throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg);
} }
...@@ -33,7 +33,7 @@ BlockingQueue<T>::Take() { ...@@ -33,7 +33,7 @@ BlockingQueue<T>::Take() {
if (queue_.empty()) { if (queue_.empty()) {
std::string error_msg = "blocking queue empty"; std::string error_msg = "blocking queue empty";
SERVER_LOG_ERROR << error_msg; //SERVER_LOG_ERROR << error_msg;
throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg); throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg);
} }
...@@ -57,7 +57,7 @@ BlockingQueue<T>::Front() { ...@@ -57,7 +57,7 @@ BlockingQueue<T>::Front() {
empty_.wait(lock, [this] { return !queue_.empty(); }); empty_.wait(lock, [this] { return !queue_.empty(); });
if (queue_.empty()) { if (queue_.empty()) {
std::string error_msg = "blocking queue empty"; std::string error_msg = "blocking queue empty";
SERVER_LOG_ERROR << error_msg; //SERVER_LOG_ERROR << error_msg;
throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg); throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg);
} }
T front(queue_.front()); T front(queue_.front());
...@@ -72,7 +72,7 @@ BlockingQueue<T>::Back() { ...@@ -72,7 +72,7 @@ BlockingQueue<T>::Back() {
if (queue_.empty()) { if (queue_.empty()) {
std::string error_msg = "blocking queue empty"; std::string error_msg = "blocking queue empty";
SERVER_LOG_ERROR << error_msg; //SERVER_LOG_ERROR << error_msg;
throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg); throw ServerException(SERVER_BLOCKING_QUEUE_EMPTY, error_msg);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册