提交 2705fa83 编写于 作者: P peng.xu

Merge branch 'branch-0.5.0' into 'branch-0.5.0'

format code

See merge request megasearch/milvus!625

Former-commit-id: adfbacea676947f3656543765eea3fbec796440e
......@@ -32,30 +32,40 @@ namespace cache {
template<typename ItemObj>
class Cache {
public:
public:
//mem_capacity, units:GB
Cache(int64_t capacity_gb, uint64_t cache_max_count);
~Cache() = default;
int64_t usage() const { return usage_; }
int64_t capacity() const { return capacity_; } //unit: BYTE
int64_t usage() const {
return usage_;
}
int64_t capacity() const {
return capacity_;
} //unit: BYTE
void set_capacity(int64_t capacity); //unit: BYTE
double freemem_percent() const { return freemem_percent_; };
void set_freemem_percent(double percent) { freemem_percent_ = percent; }
double freemem_percent() const {
return freemem_percent_;
}
void set_freemem_percent(double percent) {
freemem_percent_ = percent;
}
size_t size() const;
bool exists(const std::string& key);
ItemObj get(const std::string& key);
void insert(const std::string& key, const ItemObj& item);
void erase(const std::string& key);
bool exists(const std::string &key);
ItemObj get(const std::string &key);
void insert(const std::string &key, const ItemObj &item);
void erase(const std::string &key);
void print();
void clear();
private:
private:
void free_memory();
private:
private:
int64_t usage_;
int64_t capacity_;
double freemem_percent_;
......@@ -64,8 +74,8 @@ private:
mutable std::mutex mutex_;
};
} // cache
} // milvus
} // zilliz
} // namespace cache
} // namespace milvus
} // namespace zilliz
#include "cache/Cache.inl"
\ No newline at end of file
#include "cache/Cache.inl"
......@@ -33,29 +33,33 @@ Cache<ItemObj>::Cache(int64_t capacity, uint64_t cache_max_count)
}
template<typename ItemObj>
void Cache<ItemObj>::set_capacity(int64_t capacity) {
if(capacity > 0) {
void
Cache<ItemObj>::set_capacity(int64_t capacity) {
if (capacity > 0) {
capacity_ = capacity;
free_memory();
}
}
template<typename ItemObj>
size_t Cache<ItemObj>::size() const {
size_t
Cache<ItemObj>::size() const {
std::lock_guard<std::mutex> lock(mutex_);
return lru_.size();
}
template<typename ItemObj>
bool Cache<ItemObj>::exists(const std::string& key) {
bool
Cache<ItemObj>::exists(const std::string &key) {
std::lock_guard<std::mutex> lock(mutex_);
return lru_.exists(key);
}
template<typename ItemObj>
ItemObj Cache<ItemObj>::get(const std::string& key) {
ItemObj
Cache<ItemObj>::get(const std::string &key) {
std::lock_guard<std::mutex> lock(mutex_);
if(!lru_.exists(key)){
if (!lru_.exists(key)) {
return nullptr;
}
......@@ -63,8 +67,9 @@ ItemObj Cache<ItemObj>::get(const std::string& key) {
}
template<typename ItemObj>
void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
if(item == nullptr) {
void
Cache<ItemObj>::insert(const std::string &key, const ItemObj &item) {
if (item == nullptr) {
return;
}
......@@ -80,7 +85,7 @@ void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
//if key already exist, subtract old item size
if (lru_.exists(key)) {
const ItemObj& old_item = lru_.get(key);
const ItemObj &old_item = lru_.get(key);
usage_ -= old_item->size();
}
......@@ -107,13 +112,14 @@ void Cache<ItemObj>::insert(const std::string& key, const ItemObj& item) {
}
template<typename ItemObj>
void Cache<ItemObj>::erase(const std::string& key) {
void
Cache<ItemObj>::erase(const std::string &key) {
std::lock_guard<std::mutex> lock(mutex_);
if(!lru_.exists(key)){
if (!lru_.exists(key)) {
return;
}
const ItemObj& old_item = lru_.get(key);
const ItemObj &old_item = lru_.get(key);
usage_ -= old_item->size();
SERVER_LOG_DEBUG << "Erase " << key << " size: " << old_item->size();
......@@ -122,7 +128,8 @@ void Cache<ItemObj>::erase(const std::string& key) {
}
template<typename ItemObj>
void Cache<ItemObj>::clear() {
void
Cache<ItemObj>::clear() {
std::lock_guard<std::mutex> lock(mutex_);
lru_.clear();
usage_ = 0;
......@@ -131,12 +138,13 @@ void Cache<ItemObj>::clear() {
/* free memory space when CACHE occupation exceed its capacity */
template<typename ItemObj>
void Cache<ItemObj>::free_memory() {
void
Cache<ItemObj>::free_memory() {
if (usage_ <= capacity_) return;
int64_t threshhold = capacity_ * freemem_percent_;
int64_t delta_size = usage_ - threshhold;
if(delta_size <= 0) {
if (delta_size <= 0) {
delta_size = 1;//ensure at least one item erased
}
......@@ -148,8 +156,8 @@ void Cache<ItemObj>::free_memory() {
auto it = lru_.rbegin();
while (it != lru_.rend() && released_size < delta_size) {
auto& key = it->first;
auto& obj_ptr = it->second;
auto &key = it->first;
auto &obj_ptr = it->second;
key_array.emplace(key);
released_size += obj_ptr->size();
......@@ -159,7 +167,7 @@ void Cache<ItemObj>::free_memory() {
SERVER_LOG_DEBUG << "to be released memory size: " << released_size;
for (auto& key : key_array) {
for (auto &key : key_array) {
erase(key);
}
......@@ -167,7 +175,8 @@ void Cache<ItemObj>::free_memory() {
}
template<typename ItemObj>
void Cache<ItemObj>::print() {
void
Cache<ItemObj>::print() {
size_t cache_count = 0;
{
std::lock_guard<std::mutex> lock(mutex_);
......@@ -179,7 +188,7 @@ void Cache<ItemObj>::print() {
SERVER_LOG_DEBUG << "[Cache capacity]: " << capacity_ << " bytes";
}
} // cache
} // milvus
} // zilliz
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -22,22 +22,25 @@
#include "utils/Log.h"
#include "metrics/Metrics.h"
#include <string>
#include <memory>
namespace zilliz {
namespace milvus {
namespace cache {
template<typename ItemObj>
class CacheMgr {
public:
public:
virtual uint64_t ItemCount() const;
virtual bool ItemExists(const std::string& key);
virtual bool ItemExists(const std::string &key);
virtual ItemObj GetItem(const std::string& key);
virtual ItemObj GetItem(const std::string &key);
virtual void InsertItem(const std::string& key, const ItemObj& data);
virtual void InsertItem(const std::string &key, const ItemObj &data);
virtual void EraseItem(const std::string& key);
virtual void EraseItem(const std::string &key);
virtual void PrintInfo();
......@@ -47,18 +50,17 @@ public:
int64_t CacheCapacity() const;
void SetCapacity(int64_t capacity);
protected:
protected:
CacheMgr();
virtual ~CacheMgr();
protected:
protected:
using CachePtr = std::shared_ptr<Cache<ItemObj>>;
CachePtr cache_;
};
} // namespace cache
} // namespace milvus
} // namespace zilliz
}
}
}
#include "cache/CacheMgr.inl"
\ No newline at end of file
#include "cache/CacheMgr.inl"
......@@ -30,18 +30,20 @@ CacheMgr<ItemObj>::~CacheMgr() {
}
template<typename ItemObj>
uint64_t CacheMgr<ItemObj>::ItemCount() const {
if(cache_ == nullptr) {
uint64_t
CacheMgr<ItemObj>::ItemCount() const {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
return (uint64_t)(cache_->size());
return (uint64_t) (cache_->size());
}
template<typename ItemObj>
bool CacheMgr<ItemObj>::ItemExists(const std::string& key) {
if(cache_ == nullptr) {
bool
CacheMgr<ItemObj>::ItemExists(const std::string &key) {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return false;
}
......@@ -50,8 +52,9 @@ bool CacheMgr<ItemObj>::ItemExists(const std::string& key) {
}
template<typename ItemObj>
ItemObj CacheMgr<ItemObj>::GetItem(const std::string& key) {
if(cache_ == nullptr) {
ItemObj
CacheMgr<ItemObj>::GetItem(const std::string &key) {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return nullptr;
}
......@@ -60,8 +63,9 @@ ItemObj CacheMgr<ItemObj>::GetItem(const std::string& key) {
}
template<typename ItemObj>
void CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data) {
if(cache_ == nullptr) {
void
CacheMgr<ItemObj>::InsertItem(const std::string &key, const ItemObj &data) {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -71,8 +75,9 @@ void CacheMgr<ItemObj>::InsertItem(const std::string& key, const ItemObj& data)
}
template<typename ItemObj>
void CacheMgr<ItemObj>::EraseItem(const std::string& key) {
if(cache_ == nullptr) {
void
CacheMgr<ItemObj>::EraseItem(const std::string &key) {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -82,8 +87,9 @@ void CacheMgr<ItemObj>::EraseItem(const std::string& key) {
}
template<typename ItemObj>
void CacheMgr<ItemObj>::PrintInfo() {
if(cache_ == nullptr) {
void
CacheMgr<ItemObj>::PrintInfo() {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -92,8 +98,9 @@ void CacheMgr<ItemObj>::PrintInfo() {
}
template<typename ItemObj>
void CacheMgr<ItemObj>::ClearCache() {
if(cache_ == nullptr) {
void
CacheMgr<ItemObj>::ClearCache() {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
......@@ -102,8 +109,9 @@ void CacheMgr<ItemObj>::ClearCache() {
}
template<typename ItemObj>
int64_t CacheMgr<ItemObj>::CacheUsage() const {
if(cache_ == nullptr) {
int64_t
CacheMgr<ItemObj>::CacheUsage() const {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
......@@ -112,8 +120,9 @@ int64_t CacheMgr<ItemObj>::CacheUsage() const {
}
template<typename ItemObj>
int64_t CacheMgr<ItemObj>::CacheCapacity() const {
if(cache_ == nullptr) {
int64_t
CacheMgr<ItemObj>::CacheCapacity() const {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return 0;
}
......@@ -122,14 +131,15 @@ int64_t CacheMgr<ItemObj>::CacheCapacity() const {
}
template<typename ItemObj>
void CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
if(cache_ == nullptr) {
void
CacheMgr<ItemObj>::SetCapacity(int64_t capacity) {
if (cache_ == nullptr) {
SERVER_LOG_ERROR << "Cache doesn't exist";
return;
}
cache_->set_capacity(capacity);
}
}
}
}
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -16,20 +16,22 @@
// under the License.
#include "CpuCacheMgr.h"
#include "cache/CpuCacheMgr.h"
#include "server/Config.h"
#include "utils/Log.h"
#include <utility>
namespace zilliz {
namespace milvus {
namespace cache {
namespace {
constexpr int64_t unit = 1024 * 1024 * 1024;
constexpr int64_t unit = 1024 * 1024 * 1024;
}
CpuCacheMgr::CpuCacheMgr() {
server::Config& config = server::Config::GetInstance();
server::Config &config = server::Config::GetInstance();
Status s;
int32_t cpu_mem_cap;
......@@ -38,7 +40,7 @@ CpuCacheMgr::CpuCacheMgr() {
SERVER_LOG_ERROR << s.message();
}
int64_t cap = cpu_mem_cap * unit;
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL<<32);
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
float cpu_mem_threshold;
s = config.GetCacheConfigCpuMemThreshold(cpu_mem_threshold);
......@@ -53,20 +55,22 @@ CpuCacheMgr::CpuCacheMgr() {
}
}
CpuCacheMgr* CpuCacheMgr::GetInstance() {
CpuCacheMgr *
CpuCacheMgr::GetInstance() {
static CpuCacheMgr s_mgr;
return &s_mgr;
}
engine::VecIndexPtr CpuCacheMgr::GetIndex(const std::string& key) {
engine::VecIndexPtr
CpuCacheMgr::GetIndex(const std::string &key) {
DataObjPtr obj = GetItem(key);
if(obj != nullptr) {
if (obj != nullptr) {
return obj->data();
}
return nullptr;
}
}
}
}
\ No newline at end of file
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -20,21 +20,24 @@
#include "CacheMgr.h"
#include "DataObj.h"
#include <string>
#include <memory>
namespace zilliz {
namespace milvus {
namespace cache {
class CpuCacheMgr : public CacheMgr<DataObjPtr> {
private:
private:
CpuCacheMgr();
public:
public:
//TODO: use smart pointer instead
static CpuCacheMgr* GetInstance();
static CpuCacheMgr *GetInstance();
engine::VecIndexPtr GetIndex(const std::string& key);
engine::VecIndexPtr GetIndex(const std::string &key);
};
}
}
}
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -27,38 +27,43 @@ namespace milvus {
namespace cache {
class DataObj {
public:
DataObj(const engine::VecIndexPtr& index)
: index_(index)
{}
public:
explicit DataObj(const engine::VecIndexPtr &index)
: index_(index) {
}
DataObj(const engine::VecIndexPtr& index, int64_t size)
: index_(index),
size_(size)
{}
DataObj(const engine::VecIndexPtr &index, int64_t size)
: index_(index),
size_(size) {
}
engine::VecIndexPtr data() { return index_; }
const engine::VecIndexPtr& data() const { return index_; }
engine::VecIndexPtr data() {
return index_;
}
const engine::VecIndexPtr &data() const {
return index_;
}
int64_t size() const {
if(index_ == nullptr) {
if (index_ == nullptr) {
return 0;
}
if(size_ > 0) {
if (size_ > 0) {
return size_;
}
return index_->Count() * index_->Dimension() * sizeof(float);
}
private:
private:
engine::VecIndexPtr index_ = nullptr;
int64_t size_ = 0;
};
using DataObjPtr = std::shared_ptr<DataObj>;
}
}
}
\ No newline at end of file
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -16,11 +16,13 @@
// under the License.
#include <sstream>
#include "cache/GpuCacheMgr.h"
#include "utils/Log.h"
#include "GpuCacheMgr.h"
#include "server/Config.h"
#include <sstream>
#include <utility>
namespace zilliz {
namespace milvus {
namespace cache {
......@@ -29,11 +31,11 @@ std::mutex GpuCacheMgr::mutex_;
std::unordered_map<uint64_t, GpuCacheMgrPtr> GpuCacheMgr::instance_;
namespace {
constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
}
GpuCacheMgr::GpuCacheMgr() {
server::Config& config = server::Config::GetInstance();
server::Config &config = server::Config::GetInstance();
Status s;
int32_t gpu_mem_cap;
......@@ -42,7 +44,7 @@ GpuCacheMgr::GpuCacheMgr() {
SERVER_LOG_ERROR << s.message();
}
int32_t cap = gpu_mem_cap * G_BYTE;
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL<<32);
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
float gpu_mem_threshold;
s = config.GetCacheConfigGpuMemThreshold(gpu_mem_threshold);
......@@ -57,7 +59,8 @@ GpuCacheMgr::GpuCacheMgr() {
}
}
GpuCacheMgr* GpuCacheMgr::GetInstance(uint64_t gpu_id) {
GpuCacheMgr *
GpuCacheMgr::GetInstance(uint64_t gpu_id) {
if (instance_.find(gpu_id) == instance_.end()) {
std::lock_guard<std::mutex> lock(mutex_);
if (instance_.find(gpu_id) == instance_.end()) {
......@@ -70,15 +73,16 @@ GpuCacheMgr* GpuCacheMgr::GetInstance(uint64_t gpu_id) {
}
}
engine::VecIndexPtr GpuCacheMgr::GetIndex(const std::string& key) {
engine::VecIndexPtr
GpuCacheMgr::GetIndex(const std::string &key) {
DataObjPtr obj = GetItem(key);
if(obj != nullptr) {
if (obj != nullptr) {
return obj->data();
}
return nullptr;
}
}
}
}
\ No newline at end of file
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -21,6 +21,7 @@
#include <unordered_map>
#include <memory>
#include <string>
namespace zilliz {
namespace milvus {
......@@ -30,18 +31,18 @@ class GpuCacheMgr;
using GpuCacheMgrPtr = std::shared_ptr<GpuCacheMgr>;
class GpuCacheMgr : public CacheMgr<DataObjPtr> {
public:
public:
GpuCacheMgr();
static GpuCacheMgr* GetInstance(uint64_t gpu_id);
static GpuCacheMgr *GetInstance(uint64_t gpu_id);
engine::VecIndexPtr GetIndex(const std::string& key);
engine::VecIndexPtr GetIndex(const std::string &key);
private:
private:
static std::mutex mutex_;
static std::unordered_map<uint64_t, GpuCacheMgrPtr> instance_;
};
}
}
}
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -22,6 +22,7 @@
#include <list>
#include <cstddef>
#include <stdexcept>
#include <utility>
namespace zilliz {
namespace milvus {
......@@ -29,14 +30,15 @@ namespace cache {
template<typename key_t, typename value_t>
class LRU {
public:
public:
typedef typename std::pair<key_t, value_t> key_value_pair_t;
typedef typename std::list<key_value_pair_t>::iterator list_iterator_t;
typedef typename std::list<key_value_pair_t>::reverse_iterator reverse_list_iterator_t;
LRU(size_t max_size) : max_size_(max_size) {}
explicit LRU(size_t max_size) : max_size_(max_size) {
}
void put(const key_t& key, const value_t& value) {
void put(const key_t &key, const value_t &value) {
auto it = cache_items_map_.find(key);
cache_items_list_.push_front(key_value_pair_t(key, value));
if (it != cache_items_map_.end()) {
......@@ -53,7 +55,7 @@ public:
}
}
const value_t& get(const key_t& key) {
const value_t &get(const key_t &key) {
auto it = cache_items_map_.find(key);
if (it == cache_items_map_.end()) {
throw std::range_error("There is no such key in cache");
......@@ -63,7 +65,7 @@ public:
}
}
void erase(const key_t& key) {
void erase(const key_t &key) {
auto it = cache_items_map_.find(key);
if (it != cache_items_map_.end()) {
cache_items_list_.erase(it->second);
......@@ -71,7 +73,7 @@ public:
}
}
bool exists(const key_t& key) const {
bool exists(const key_t &key) const {
return cache_items_map_.find(key) != cache_items_map_.end();
}
......@@ -101,14 +103,14 @@ public:
cache_items_map_.clear();
}
private:
private:
std::list<key_value_pair_t> cache_items_list_;
std::unordered_map<key_t, list_iterator_t> cache_items_map_;
size_t max_size_;
list_iterator_t iter_;
};
} // cache
} // milvus
} // zilliz
} // namespace cache
} // namespace milvus
} // namespace zilliz
......@@ -76,7 +76,6 @@ main(int argc, char *argv[]) {
std::cout << "Initial log config from: " << log_config_file << std::endl;
break;
}
case 'p': {
char *pid_filename_ptr = strdup(optarg);
pid_filename = pid_filename_ptr;
......@@ -84,7 +83,6 @@ main(int argc, char *argv[]) {
std::cout << pid_filename << std::endl;
break;
}
case 'd':
start_daemonized = 1;
break;
......
......@@ -23,8 +23,8 @@
#include "src/ClientTest.h"
void print_help(const std::string &app_name);
void
print_help(const std::string &app_name);
int
main(int argc, char *argv[]) {
......@@ -56,8 +56,7 @@ main(int argc, char *argv[]) {
break;
}
case 'h':
default:
print_help(app_name);
default:print_help(app_name);
return EXIT_SUCCESS;
}
}
......@@ -77,4 +76,4 @@ print_help(const std::string &app_name) {
printf(" -p --port Server port, default 19530\n");
printf(" -h --help Print help information\n");
printf("\n");
}
\ No newline at end of file
}
......@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
#include "ClientTest.h"
#include "sdk/examples/grpcsimple/src/ClientTest.h"
#include "MilvusApi.h"
#include "cache/CpuCacheMgr.h"
......@@ -24,15 +24,17 @@
#include <chrono>
#include <thread>
#include <unistd.h>
using namespace milvus;
#include <memory>
#include <vector>
#include <utility>
//#define SET_VECTOR_IDS;
namespace {
std::string GetTableName();
const std::string&
GetTableName();
const std::string TABLE_NAME = GetTableName();
const char* TABLE_NAME = GetTableName().c_str();
constexpr int64_t TABLE_DIMENSION = 512;
constexpr int64_t TABLE_INDEX_FILE_SIZE = 1024;
constexpr int64_t BATCH_ROW_COUNT = 100000;
......@@ -44,26 +46,28 @@ constexpr int64_t SECONDS_EACH_HOUR = 3600;
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
void PrintTableSchema(const TableSchema& tb_schema) {
void
PrintTableSchema(const milvus::TableSchema &tb_schema) {
BLOCK_SPLITER
std::cout << "Table name: " << tb_schema.table_name << std::endl;
std::cout << "Table dimension: " << tb_schema.dimension << std::endl;
BLOCK_SPLITER
}
void PrintSearchResult(const std::vector<std::pair<int64_t, RowRecord>>& search_record_array,
const std::vector<TopKQueryResult>& topk_query_result_array) {
void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::RowRecord>> &search_record_array,
const std::vector<milvus::TopKQueryResult> &topk_query_result_array) {
BLOCK_SPLITER
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
int32_t index = 0;
for(auto& result : topk_query_result_array) {
for (auto &result : topk_query_result_array) {
auto search_id = search_record_array[index].first;
index++;
std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id)
<< " top " << std::to_string(result.query_result_arrays.size())
<< " search result:" << std::endl;
for(auto& item : result.query_result_arrays) {
for (auto &item : result.query_result_arrays) {
std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance);
std::cout << std::endl;
}
......@@ -72,80 +76,88 @@ void PrintSearchResult(const std::vector<std::pair<int64_t, RowRecord>>& search_
BLOCK_SPLITER
}
std::string CurrentTime() {
std::string
CurrentTime() {
time_t tt;
time( &tt );
tt = tt + 8*SECONDS_EACH_HOUR;
tm* t= gmtime( &tt );
time(&tt);
tt = tt + 8 * SECONDS_EACH_HOUR;
tm t;
gmtime_r(&tt, &t);
std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1)
+ "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour)
+ "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec);
std::string str = std::to_string(t.tm_year + 1900) + "_" + std::to_string(t.tm_mon + 1)
+ "_" + std::to_string(t.tm_mday) + "_" + std::to_string(t.tm_hour)
+ "_" + std::to_string(t.tm_min) + "_" + std::to_string(t.tm_sec);
return str;
}
std::string CurrentTmDate(int64_t offset_day = 0) {
std::string
CurrentTmDate(int64_t offset_day = 0) {
time_t tt;
time( &tt );
tt = tt + 8*SECONDS_EACH_HOUR;
tt = tt + 24*SECONDS_EACH_HOUR*offset_day;
tm* t= gmtime( &tt );
time(&tt);
tt = tt + 8 * SECONDS_EACH_HOUR;
tt = tt + 24 * SECONDS_EACH_HOUR * offset_day;
tm t;
gmtime_r(&tt, &t);
std::string str = std::to_string(t->tm_year + 1900) + "-" + std::to_string(t->tm_mon + 1)
+ "-" + std::to_string(t->tm_mday);
std::string str = std::to_string(t.tm_year + 1900) + "-" + std::to_string(t.tm_mon + 1)
+ "-" + std::to_string(t.tm_mday);
return str;
}
std::string GetTableName() {
static std::string s_id(CurrentTime());
return "tbl_" + s_id;
const std::string&
GetTableName() {
static std::string s_id("tbl_" + CurrentTime());
return s_id;
}
TableSchema BuildTableSchema() {
TableSchema tb_schema;
milvus::TableSchema
BuildTableSchema() {
milvus::TableSchema tb_schema;
tb_schema.table_name = TABLE_NAME;
tb_schema.dimension = TABLE_DIMENSION;
tb_schema.index_file_size = TABLE_INDEX_FILE_SIZE;
tb_schema.metric_type = MetricType::L2;
tb_schema.metric_type = milvus::MetricType::L2;
return tb_schema;
}
void BuildVectors(int64_t from, int64_t to,
std::vector<RowRecord>& vector_record_array) {
if(to <= from){
void
BuildVectors(int64_t from, int64_t to,
std::vector<milvus::RowRecord> &vector_record_array) {
if (to <= from) {
return;
}
vector_record_array.clear();
for (int64_t k = from; k < to; k++) {
RowRecord record;
milvus::RowRecord record;
record.data.resize(TABLE_DIMENSION);
for(int64_t i = 0; i < TABLE_DIMENSION; i++) {
record.data[i] = (float)(k%(i+1));
for (int64_t i = 0; i < TABLE_DIMENSION; i++) {
record.data[i] = (float) (k % (i + 1));
}
vector_record_array.emplace_back(record);
}
}
void Sleep(int seconds) {
void
Sleep(int seconds) {
std::cout << "Waiting " << seconds << " seconds ..." << std::endl;
sleep(seconds);
}
class TimeRecorder {
public:
explicit TimeRecorder(const std::string& title)
explicit TimeRecorder(const std::string &title)
: title_(title) {
start_ = std::chrono::system_clock::now();
}
~TimeRecorder() {
std::chrono::system_clock::time_point end = std::chrono::system_clock::now();
long span = (std::chrono::duration_cast<std::chrono::milliseconds> (end - start_)).count();
int64_t span = (std::chrono::duration_cast<std::chrono::milliseconds>(end - start_)).count();
std::cout << title_ << " totally cost: " << span << " ms" << std::endl;
}
......@@ -154,14 +166,15 @@ class TimeRecorder {
std::chrono::system_clock::time_point start_;
};
void CheckResult(const std::vector<std::pair<int64_t, RowRecord>>& search_record_array,
const std::vector<TopKQueryResult>& topk_query_result_array) {
void
CheckResult(const std::vector<std::pair<int64_t, milvus::RowRecord>> &search_record_array,
const std::vector<milvus::TopKQueryResult> &topk_query_result_array) {
BLOCK_SPLITER
int64_t index = 0;
for(auto& result : topk_query_result_array) {
for (auto &result : topk_query_result_array) {
auto result_id = result.query_result_arrays[0].id;
auto search_id = search_record_array[index++].first;
if(result_id != search_id) {
if (result_id != search_id) {
std::cout << "The top 1 result is wrong: " << result_id
<< " vs. " << search_id << std::endl;
} else {
......@@ -171,42 +184,45 @@ void CheckResult(const std::vector<std::pair<int64_t, RowRecord>>& search_record
BLOCK_SPLITER
}
void DoSearch(std::shared_ptr<Connection> conn,
const std::vector<std::pair<int64_t, RowRecord>>& search_record_array,
const std::string& phase_name) {
std::vector<Range> query_range_array;
Range rg;
void
DoSearch(std::shared_ptr<milvus::Connection> conn,
const std::vector<std::pair<int64_t, milvus::RowRecord>> &search_record_array,
const std::string &phase_name) {
std::vector<milvus::Range> query_range_array;
milvus::Range rg;
rg.start_value = CurrentTmDate();
rg.end_value = CurrentTmDate(1);
query_range_array.emplace_back(rg);
std::vector<RowRecord> record_array;
for(auto& pair : search_record_array) {
std::vector<milvus::RowRecord> record_array;
for (auto &pair : search_record_array) {
record_array.push_back(pair.second);
}
auto start = std::chrono::high_resolution_clock::now();
std::vector<TopKQueryResult> topk_query_result_array;
std::vector<milvus::TopKQueryResult> topk_query_result_array;
{
TimeRecorder rc(phase_name);
Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array);
milvus::Status stat =
conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array);
std::cout << "SearchVector function call status: " << stat.message() << std::endl;
}
auto finish = std::chrono::high_resolution_clock::now();
std::cout << "SEARCHVECTOR COST: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
std::cout << "SEARCHVECTOR COST: "
<< std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
PrintSearchResult(search_record_array, topk_query_result_array);
CheckResult(search_record_array, topk_query_result_array);
}
}
} // namespace
void
ClientTest::Test(const std::string& address, const std::string& port) {
std::shared_ptr<Connection> conn = Connection::Create();
ClientTest::Test(const std::string &address, const std::string &port) {
std::shared_ptr<milvus::Connection> conn = milvus::Connection::Create();
{//connect server
ConnectParam param = {address, port};
Status stat = conn->Connect(param);
milvus::ConnectParam param = {address, port};
milvus::Status stat = conn->Connect(param);
std::cout << "Connect function call status: " << stat.message() << std::endl;
}
......@@ -222,10 +238,10 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{
std::vector<std::string> tables;
Status stat = conn->ShowTables(tables);
milvus::Status stat = conn->ShowTables(tables);
std::cout << "ShowTables function call status: " << stat.message() << std::endl;
std::cout << "All tables: " << std::endl;
for(auto& table : tables) {
for (auto &table : tables) {
int64_t row_count = 0;
// conn->DropTable(table);
stat = conn->CountTable(table, row_count);
......@@ -234,28 +250,28 @@ ClientTest::Test(const std::string& address, const std::string& port) {
}
{//create table
TableSchema tb_schema = BuildTableSchema();
Status stat = conn->CreateTable(tb_schema);
milvus::TableSchema tb_schema = BuildTableSchema();
milvus::Status stat = conn->CreateTable(tb_schema);
std::cout << "CreateTable function call status: " << stat.message() << std::endl;
PrintTableSchema(tb_schema);
bool has_table = conn->HasTable(tb_schema.table_name);
if(has_table) {
if (has_table) {
std::cout << "Table is created" << std::endl;
}
}
{//describe table
TableSchema tb_schema;
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
milvus::TableSchema tb_schema;
milvus::Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
std::cout << "DescribeTable function call status: " << stat.message() << std::endl;
PrintTableSchema(tb_schema);
}
std::vector<std::pair<int64_t, RowRecord>> search_record_array;
std::vector<std::pair<int64_t, milvus::RowRecord>> search_record_array;
{//insert vectors
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
std::vector<RowRecord> record_array;
std::vector<milvus::RowRecord> record_array;
int64_t begin_index = i * BATCH_ROW_COUNT;
BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
......@@ -268,21 +284,21 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::vector<int64_t> record_ids;
//generate user defined ids
for(int k = 0; k < BATCH_ROW_COUNT; k++) {
record_ids.push_back(i*BATCH_ROW_COUNT+k);
for (int k = 0; k < BATCH_ROW_COUNT; k++) {
record_ids.push_back(i * BATCH_ROW_COUNT + k);
}
auto start = std::chrono::high_resolution_clock::now();
Status stat = conn->Insert(TABLE_NAME, record_array, record_ids);
milvus::Status stat = conn->Insert(TABLE_NAME, record_array, record_ids);
auto finish = std::chrono::high_resolution_clock::now();
std::cout << "InsertVector cost: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
std::cout << "InsertVector cost: "
<< std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
std::cout << "InsertVector function call status: " << stat.message() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
if(search_record_array.size() < NQ) {
if (search_record_array.size() < NQ) {
search_record_array.push_back(
std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
}
......@@ -293,27 +309,27 @@ ClientTest::Test(const std::string& address, const std::string& port) {
Sleep(2);
int64_t row_count = 0;
Status stat = conn->CountTable(TABLE_NAME, row_count);
milvus::Status stat = conn->CountTable(TABLE_NAME, row_count);
std::cout << TABLE_NAME << "(" << row_count << " rows)" << std::endl;
// DoSearch(conn, search_record_array, "Search without index");
}
{//wait unit build index finish
std::cout << "Wait until create all index done" << std::endl;
IndexParam index;
milvus::IndexParam index;
index.table_name = TABLE_NAME;
index.index_type = IndexType::gpu_ivfsq8;
index.index_type = milvus::IndexType::gpu_ivfsq8;
index.nlist = 16384;
Status stat = conn->CreateIndex(index);
milvus::Status stat = conn->CreateIndex(index);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
IndexParam index2;
milvus::IndexParam index2;
stat = conn->DescribeIndex(TABLE_NAME, index2);
std::cout << "DescribeIndex function call status: " << stat.message() << std::endl;
}
{//preload table
Status stat = conn->PreloadTable(TABLE_NAME);
milvus::Status stat = conn->PreloadTable(TABLE_NAME);
std::cout << "PreloadTable function call status: " << stat.message() << std::endl;
}
......@@ -325,7 +341,7 @@ ClientTest::Test(const std::string& address, const std::string& port) {
}
{//delete index
Status stat = conn->DropIndex(TABLE_NAME);
milvus::Status stat = conn->DropIndex(TABLE_NAME);
std::cout << "DropIndex function call status: " << stat.message() << std::endl;
int64_t row_count = 0;
......@@ -334,11 +350,11 @@ ClientTest::Test(const std::string& address, const std::string& port) {
}
{//delete by range
Range rg;
milvus::Range rg;
rg.start_value = CurrentTmDate(-2);
rg.end_value = CurrentTmDate(-3);
Status stat = conn->DeleteByRange(rg, TABLE_NAME);
milvus::Status stat = conn->DeleteByRange(rg, TABLE_NAME);
std::cout << "DeleteByRange function call status: " << stat.message() << std::endl;
}
......@@ -351,7 +367,7 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::string status = conn->ServerStatus();
std::cout << "Server status before disconnect: " << status << std::endl;
}
Connection::Destroy(conn);
milvus::Connection::Destroy(conn);
{//server status
std::string status = conn->ServerStatus();
std::cout << "Server status after disconnect: " << status << std::endl;
......
......@@ -20,6 +20,6 @@
#include <string>
class ClientTest {
public:
void Test(const std::string& address, const std::string& port);
};
\ No newline at end of file
public:
void Test(const std::string &address, const std::string &port);
};
......@@ -15,16 +15,20 @@
// specific language governing permissions and limitations
// under the License.
#include "ClientProxy.h"
#include "version.h"
#include "milvus.grpc.pb.h"
#include "sdk/grpc/ClientProxy.h"
#include "../../../version.h"
#include "grpc/gen-milvus/milvus.grpc.pb.h"
#include <memory>
#include <vector>
#include <string>
//#define GRPC_MULTIPLE_THREAD;
namespace milvus {
bool
UriCheck(const std::string &uri) {
size_t index = uri.find_first_of(':', 0);
size_t index = uri.find_first_of(':', 0);
if (index == std::string::npos) {
return false;
} else {
......@@ -79,7 +83,7 @@ ClientProxy::Disconnect() {
connected_ = false;
channel_.reset();
return status;
}catch (std::exception &ex) {
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "failed to disconnect: " + std::string(ex.what()));
}
}
......@@ -96,7 +100,7 @@ ClientProxy::CreateTable(const TableSchema &param) {
schema.set_table_name(param.table_name);
schema.set_dimension(param.dimension);
schema.set_index_file_size(param.index_file_size);
schema.set_metric_type((int32_t)param.metric_type);
schema.set_metric_type((int32_t) param.metric_type);
return client_ptr_->CreateTable(schema);
} catch (std::exception &ex) {
......@@ -127,13 +131,12 @@ ClientProxy::DropTable(const std::string &table_name) {
Status
ClientProxy::CreateIndex(const IndexParam &index_param) {
try {
//TODO:add index params
//TODO: add index params
::milvus::grpc::IndexParam grpc_index_param;
grpc_index_param.set_table_name(index_param.table_name);
grpc_index_param.mutable_index()->set_index_type((int32_t)index_param.index_type);
grpc_index_param.mutable_index()->set_index_type((int32_t) index_param.index_type);
grpc_index_param.mutable_index()->set_nlist(index_param.nlist);
return client_ptr_->CreateIndex(grpc_index_param);
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "failed to build index: " + std::string(ex.what()));
}
......@@ -141,8 +144,8 @@ ClientProxy::CreateIndex(const IndexParam &index_param) {
Status
ClientProxy::Insert(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
Status status = Status::OK();
try {
////////////////////////////////////////////////////////////////////////////
......@@ -181,7 +184,9 @@ ClientProxy::Insert(const std::string &table_name,
}
std::for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join));
auto finish = std::chrono::high_resolution_clock::now();
std::cout << "InsertVector cost: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
std::cout <<
"InsertVector cost: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count()
<< "s\n";
std::cout << "*****************************************************\n";
for (size_t i = 0; i < thread_count; i++) {
......@@ -213,9 +218,7 @@ ClientProxy::Insert(const std::string &table_name,
id_array.push_back(vector_ids.vector_id_array(i));
}
}
#endif
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "fail to add vector: " + std::string(ex.what()));
}
......@@ -225,11 +228,11 @@ ClientProxy::Insert(const std::string &table_name,
Status
ClientProxy::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
try {
//step 1: convert vectors data
::milvus::grpc::SearchParam search_param;
......@@ -267,11 +270,9 @@ ClientProxy::Search(const std::string &table_name,
topk_query_result_array.emplace_back(result);
}
return status;
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "fail to search vectors: " + std::string(ex.what()));
}
}
Status
......@@ -284,13 +285,12 @@ ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_sch
table_schema.table_name = grpc_schema.table_name();
table_schema.dimension = grpc_schema.dimension();
table_schema.index_file_size = grpc_schema.index_file_size();
table_schema.metric_type = (MetricType)grpc_schema.metric_type();
table_schema.metric_type = (MetricType) grpc_schema.metric_type();
return status;
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "fail to describe table: " + std::string(ex.what()));
}
}
Status
......@@ -316,7 +316,6 @@ ClientProxy::ShowTables(std::vector<std::string> &table_array) {
table_array[i] = table_name_list.table_names(i);
}
return status;
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "fail to show tables: " + std::string(ex.what()));
}
......@@ -396,11 +395,10 @@ ClientProxy::DescribeIndex(const std::string &table_name, IndexParam &index_para
grpc_table_name.set_table_name(table_name);
::milvus::grpc::IndexParam grpc_index_param;
Status status = client_ptr_->DescribeIndex(grpc_table_name, grpc_index_param);
index_param.index_type = (IndexType)(grpc_index_param.mutable_index()->index_type());
index_param.index_type = (IndexType) (grpc_index_param.mutable_index()->index_type());
index_param.nlist = grpc_index_param.mutable_index()->nlist();
return status;
} catch (std::exception &ex) {
return Status(StatusCode::UnknownError, "fail to describe index: " + std::string(ex.what()));
}
......@@ -418,4 +416,4 @@ ClientProxy::DropIndex(const std::string &table_name) const {
}
}
}
} // namespace milvus
......@@ -20,88 +20,92 @@
#include "MilvusApi.h"
#include "GrpcClient.h"
#include <vector>
#include <string>
#include <memory>
namespace milvus {
class ClientProxy : public Connection {
public:
public:
// Implementations of the Connection interface
virtual Status
Status
Connect(const ConnectParam &param) override;
virtual Status
Status
Connect(const std::string &uri) override;
virtual Status
Status
Connected() const override;
virtual Status
Status
Disconnect() override;
virtual Status
Status
CreateTable(const TableSchema &param) override;
virtual bool
bool
HasTable(const std::string &table_name) override;
virtual Status
Status
DropTable(const std::string &table_name) override;
virtual Status
Status
CreateIndex(const IndexParam &index_param) override;
virtual Status
Status
Insert(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status
Status
Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status
Status
DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status
Status
CountTable(const std::string &table_name, int64_t &row_count) override;
virtual Status
Status
ShowTables(std::vector<std::string> &table_array) override;
virtual std::string
std::string
ClientVersion() const override;
virtual std::string
std::string
ServerVersion() const override;
virtual std::string
std::string
ServerStatus() const override;
virtual std::string
std::string
DumpTaskTables() const override;
virtual Status
Status
DeleteByRange(Range &range,
const std::string &table_name) override;
virtual Status
Status
PreloadTable(const std::string &table_name) const override;
virtual Status
Status
DescribeIndex(const std::string &table_name, IndexParam &index_param) const override;
virtual Status
Status
DropIndex(const std::string &table_name) const override;
private:
private:
std::shared_ptr<::grpc::Channel> channel_;
private:
private:
std::shared_ptr<GrpcClient> client_ptr_;
bool connected_ = false;
};
}
} // namespace milvus
......@@ -15,13 +15,17 @@
// specific language governing permissions and limitations
// under the License.
#include "sdk/grpc/GrpcClient.h"
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include "GrpcClient.h"
#include <vector>
#include <string>
#include <memory>
using grpc::Channel;
using grpc::ClientContext;
......@@ -31,15 +35,14 @@ using grpc::ClientWriter;
using grpc::Status;
namespace milvus {
GrpcClient::GrpcClient(std::shared_ptr<::grpc::Channel>& channel)
: stub_(::milvus::grpc::MilvusService::NewStub(channel)) {
GrpcClient::GrpcClient(std::shared_ptr<::grpc::Channel> &channel)
: stub_(::milvus::grpc::MilvusService::NewStub(channel)) {
}
GrpcClient::~GrpcClient() = default;
Status
GrpcClient::CreateTable(const ::milvus::grpc::TableSchema& table_schema) {
GrpcClient::CreateTable(const ::milvus::grpc::TableSchema &table_schema) {
ClientContext context;
grpc::Status response;
::grpc::Status grpc_status = stub_->CreateTable(&context, table_schema, &response);
......@@ -57,8 +60,8 @@ GrpcClient::CreateTable(const ::milvus::grpc::TableSchema& table_schema) {
}
bool
GrpcClient::HasTable(const ::milvus::grpc::TableName& table_name,
Status& status) {
GrpcClient::HasTable(const ::milvus::grpc::TableName &table_name,
Status &status) {
ClientContext context;
::milvus::grpc::BoolReply response;
::grpc::Status grpc_status = stub_->HasTable(&context, table_name, &response);
......@@ -76,7 +79,7 @@ GrpcClient::HasTable(const ::milvus::grpc::TableName& table_name,
}
Status
GrpcClient::DropTable(const ::milvus::grpc::TableName& table_name) {
GrpcClient::DropTable(const ::milvus::grpc::TableName &table_name) {
ClientContext context;
grpc::Status response;
::grpc::Status grpc_status = stub_->DropTable(&context, table_name, &response);
......@@ -94,7 +97,7 @@ GrpcClient::DropTable(const ::milvus::grpc::TableName& table_name) {
}
Status
GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam& index_param) {
GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam &index_param) {
ClientContext context;
grpc::Status response;
::grpc::Status grpc_status = stub_->CreateIndex(&context, index_param, &response);
......@@ -112,9 +115,9 @@ GrpcClient::CreateIndex(const ::milvus::grpc::IndexParam& index_param) {
}
void
GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids,
const ::milvus::grpc::InsertParam& insert_param,
Status& status) {
GrpcClient::Insert(::milvus::grpc::VectorIds &vector_ids,
const ::milvus::grpc::InsertParam &insert_param,
Status &status) {
ClientContext context;
::grpc::Status grpc_status = stub_->Insert(&context, insert_param, &vector_ids);
......@@ -133,7 +136,7 @@ GrpcClient::Insert(::milvus::grpc::VectorIds& vector_ids,
}
Status
GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
GrpcClient::Search(::milvus::grpc::TopKQueryResultList &topk_query_result_list,
const ::milvus::grpc::SearchParam &search_param) {
::milvus::grpc::TopKQueryResult query_result;
ClientContext context;
......@@ -154,8 +157,8 @@ GrpcClient::Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
}
Status
GrpcClient::DescribeTable(::milvus::grpc::TableSchema& grpc_schema,
const std::string& table_name) {
GrpcClient::DescribeTable(::milvus::grpc::TableSchema &grpc_schema,
const std::string &table_name) {
ClientContext context;
::milvus::grpc::TableName grpc_tablename;
grpc_tablename.set_table_name(table_name);
......@@ -170,14 +173,14 @@ GrpcClient::DescribeTable(::milvus::grpc::TableSchema& grpc_schema,
if (grpc_schema.status().error_code() != grpc::SUCCESS) {
std::cerr << grpc_schema.status().reason() << std::endl;
return Status(StatusCode::ServerFailed,
grpc_schema.status().reason());
grpc_schema.status().reason());
}
return Status::OK();
}
int64_t
GrpcClient::CountTable(const std::string& table_name, Status& status) {
GrpcClient::CountTable(const std::string &table_name, Status &status) {
ClientContext context;
::milvus::grpc::TableRowCount response;
::milvus::grpc::TableName grpc_tablename;
......@@ -186,7 +189,7 @@ GrpcClient::CountTable(const std::string& table_name, Status& status) {
if (!grpc_status.ok()) {
std::cerr << "DescribeTable rpc failed!" << std::endl;
status = Status(StatusCode::RPCFailed, grpc_status.error_message());
status = Status(StatusCode::RPCFailed, grpc_status.error_message());
return -1;
}
......@@ -223,7 +226,7 @@ GrpcClient::ShowTables(milvus::grpc::TableNameList &table_name_list) {
Status
GrpcClient::Cmd(std::string &result,
const std::string& cmd) {
const std::string &cmd) {
ClientContext context;
::milvus::grpc::StringReply response;
::milvus::grpc::Command command;
......@@ -321,4 +324,4 @@ GrpcClient::DropIndex(grpc::TableName &table_name) {
return Status::OK();
}
}
\ No newline at end of file
} // namespace milvus
......@@ -16,6 +16,11 @@
// under the License.
#pragma once
#include "MilvusApi.h"
#include "grpc/gen-milvus/milvus.grpc.pb.h"
//#include "grpc/gen-status/status.grpc.pb.h"
#include <chrono>
#include <iostream>
#include <memory>
......@@ -28,55 +33,48 @@
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include "MilvusApi.h"
#include "milvus.grpc.pb.h"
//#include "status.grpc.pb.h"
#include <memory>
namespace milvus {
class GrpcClient {
public:
explicit
GrpcClient(std::shared_ptr<::grpc::Channel>& channel);
public:
explicit GrpcClient(std::shared_ptr<::grpc::Channel> &channel);
virtual
~GrpcClient();
Status
CreateTable(const grpc::TableSchema& table_schema);
CreateTable(const grpc::TableSchema &table_schema);
bool
HasTable(const grpc::TableName& table_name, Status& status);
HasTable(const grpc::TableName &table_name, Status &status);
Status
DropTable(const grpc::TableName& table_name);
DropTable(const grpc::TableName &table_name);
Status
CreateIndex(const grpc::IndexParam& index_param);
CreateIndex(const grpc::IndexParam &index_param);
void
Insert(grpc::VectorIds& vector_ids,
const grpc::InsertParam& insert_param,
Status& status);
Insert(grpc::VectorIds &vector_ids,
const grpc::InsertParam &insert_param,
Status &status);
Status
Search(::milvus::grpc::TopKQueryResultList& topk_query_result_list,
Search(::milvus::grpc::TopKQueryResultList &topk_query_result_list,
const grpc::SearchParam &search_param);
Status
DescribeTable(grpc::TableSchema& grpc_schema,
const std::string& table_name);
DescribeTable(grpc::TableSchema &grpc_schema,
const std::string &table_name);
int64_t
CountTable(const std::string& table_name, Status& status);
CountTable(const std::string &table_name, Status &status);
Status
ShowTables(milvus::grpc::TableNameList &table_name_list);
Status
Cmd(std::string &result, const std::string& cmd);
Cmd(std::string &result, const std::string &cmd);
Status
DeleteByRange(grpc::DeleteByRangeParam &delete_by_range_param);
......@@ -93,8 +91,8 @@ public:
Status
Disconnect();
private:
private:
std::unique_ptr<grpc::MilvusService::Stub> stub_;
};
}
} // namespace milvus
......@@ -28,7 +28,6 @@
*/
namespace milvus {
/**
* @brief Index Type
*/
......@@ -108,7 +107,6 @@ struct IndexParam {
*/
class Connection {
public:
/**
* @brief CreateConnection
*
......@@ -131,7 +129,7 @@ class Connection {
*/
static Status
Destroy(std::shared_ptr<Connection>& connection_ptr);
Destroy(std::shared_ptr<Connection> &connection_ptr);
/**
* @brief Connect
......@@ -180,7 +178,6 @@ class Connection {
virtual Status
Disconnect() = 0;
/**
* @brief Create table method
*
......@@ -193,7 +190,6 @@ class Connection {
virtual Status
CreateTable(const TableSchema &param) = 0;
/**
* @brief Test table existence method
*
......@@ -206,7 +202,6 @@ class Connection {
virtual bool
HasTable(const std::string &table_name) = 0;
/**
* @brief Delete table method
*
......@@ -248,8 +243,8 @@ class Connection {
*/
virtual Status
Insert(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) = 0;
/**
* @brief Search vector
......@@ -266,11 +261,11 @@ class Connection {
*/
virtual Status
Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) = 0;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) = 0;
/**
* @brief Show table description
......@@ -297,7 +292,7 @@ class Connection {
*/
virtual Status
CountTable(const std::string &table_name,
int64_t &row_count) = 0;
int64_t &row_count) = 0;
/**
* @brief Show all tables in database
......@@ -395,4 +390,4 @@ class Connection {
DropIndex(const std::string &table_name) const = 0;
};
}
\ No newline at end of file
} // namespace milvus
......@@ -29,6 +29,7 @@ namespace milvus {
*/
enum class StatusCode {
OK = 0,
// system error section
UnknownError = 1,
NotSupported,
......@@ -44,7 +45,7 @@ enum class StatusCode {
* @brief Status for SDK interface return
*/
class Status {
public:
public:
Status(StatusCode code, const std::string &msg);
Status();
~Status();
......@@ -60,28 +61,32 @@ public:
operator=(Status &&s);
static Status
OK() { return Status(); }
OK() {
return Status();
}
bool
ok() const { return state_ == nullptr || code() == StatusCode::OK; }
ok() const {
return state_ == nullptr || code() == StatusCode::OK;
}
StatusCode
code() const {
return (state_ == nullptr) ? StatusCode::OK : *(StatusCode*)(state_);
return (state_ == nullptr) ? StatusCode::OK : *(StatusCode *) (state_);
}
std::string
message() const;
private:
private:
inline void
CopyFrom(const Status &s);
inline void
MoveFrom(Status &s);
private:
private:
const char *state_ = nullptr;
}; // Status
} //Milvus
\ No newline at end of file
} // namespace milvus
......@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
#include "ConnectionImpl.h"
#include "sdk/interface/ConnectionImpl.h"
namespace milvus {
......@@ -25,7 +25,7 @@ Connection::Create() {
}
Status
Connection::Destroy(std::shared_ptr<milvus::Connection>& connection_ptr) {
Connection::Destroy(std::shared_ptr<milvus::Connection> &connection_ptr) {
if (connection_ptr != nullptr) {
return connection_ptr->Disconnect();
}
......@@ -84,19 +84,18 @@ ConnectionImpl::CreateIndex(const IndexParam &index_param) {
Status
ConnectionImpl::Insert(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
return client_proxy_->Insert(table_name, record_array, id_array);
}
Status
ConnectionImpl::Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) {
return client_proxy_->Search(table_name, query_record_array, query_range_array, topk,
nprobe, topk_query_result_array);
}
......@@ -133,7 +132,7 @@ ConnectionImpl::DumpTaskTables() const {
Status
ConnectionImpl::DeleteByRange(Range &range,
const std::string &table_name) {
const std::string &table_name) {
return client_proxy_->DeleteByRange(range, table_name);
}
......@@ -143,7 +142,7 @@ ConnectionImpl::PreloadTable(const std::string &table_name) const {
}
Status
ConnectionImpl::DescribeIndex(const std::string &table_name, IndexParam& index_param) const {
ConnectionImpl::DescribeIndex(const std::string &table_name, IndexParam &index_param) const {
return client_proxy_->DescribeIndex(table_name, index_param);
}
......@@ -152,4 +151,4 @@ ConnectionImpl::DropIndex(const std::string &table_name) const {
return client_proxy_->DropIndex(table_name);
}
}
} // namespace milvus
......@@ -18,88 +18,92 @@
#pragma once
#include "MilvusApi.h"
#include "src/sdk/grpc/ClientProxy.h"
#include "sdk/grpc/ClientProxy.h"
#include <vector>
#include <memory>
#include <string>
namespace milvus {
class ConnectionImpl : public Connection {
public:
public:
ConnectionImpl();
// Implementations of the Connection interface
virtual Status
Status
Connect(const ConnectParam &param) override;
virtual Status
Status
Connect(const std::string &uri) override;
virtual Status
Status
Connected() const override;
virtual Status
Status
Disconnect() override;
virtual Status
Status
CreateTable(const TableSchema &param) override;
virtual
bool HasTable(const std::string &table_name) override;
bool
HasTable(const std::string &table_name) override;
virtual Status
Status
DropTable(const std::string &table_name) override;
virtual Status
Status
CreateIndex(const IndexParam &index_param) override;
virtual Status
Status
Insert(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status
Status
Search(const std::string &table_name,
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
const std::vector<RowRecord> &query_record_array,
const std::vector<Range> &query_range_array,
int64_t topk,
int64_t nprobe,
std::vector<TopKQueryResult> &topk_query_result_array) override;
virtual Status
Status
DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status
Status
CountTable(const std::string &table_name, int64_t &row_count) override;
virtual Status
Status
ShowTables(std::vector<std::string> &table_array) override;
virtual std::string
std::string
ClientVersion() const override;
virtual std::string
std::string
ServerVersion() const override;
virtual std::string
std::string
ServerStatus() const override;
virtual std::string
std::string
DumpTaskTables() const override;
virtual Status
Status
DeleteByRange(Range &range,
const std::string &table_name) override;
virtual Status
Status
PreloadTable(const std::string &table_name) const override;
virtual Status
DescribeIndex(const std::string &table_name, IndexParam& index_param) const override;
Status
DescribeIndex(const std::string &table_name, IndexParam &index_param) const override;
virtual Status
Status
DropIndex(const std::string &table_name) const override;
private:
private:
std::shared_ptr<ClientProxy> client_proxy_;
};
}
} // namespace milvus
......@@ -23,12 +23,12 @@ namespace milvus {
constexpr int CODE_WIDTH = sizeof(StatusCode);
Status::Status(StatusCode code, const std::string& msg) {
Status::Status(StatusCode code, const std::string &msg) {
//4 bytes store code
//4 bytes store message length
//the left bytes store message string
const uint32_t length = (uint32_t)msg.size();
char* result = new char[length + sizeof(length) + CODE_WIDTH];
const uint32_t length = (uint32_t) msg.size();
char *result = new char[length + sizeof(length) + CODE_WIDTH];
memcpy(result, &code, CODE_WIDTH);
memcpy(result + CODE_WIDTH, &length, sizeof(length));
memcpy(result + sizeof(length) + CODE_WIDTH, msg.data(), length);
......@@ -37,8 +37,7 @@ Status::Status(StatusCode code, const std::string& msg) {
}
Status::Status()
: state_(nullptr) {
: state_(nullptr) {
}
Status::~Status() {
......@@ -46,22 +45,22 @@ Status::~Status() {
}
Status::Status(const Status &s)
: state_(nullptr) {
: state_(nullptr) {
CopyFrom(s);
}
Status&
Status &
Status::operator=(const Status &s) {
CopyFrom(s);
return *this;
}
Status::Status(Status &&s)
: state_(nullptr) {
: state_(nullptr) {
MoveFrom(s);
}
Status&
Status &
Status::operator=(Status &&s) {
MoveFrom(s);
return *this;
......@@ -71,7 +70,7 @@ void
Status::CopyFrom(const Status &s) {
delete state_;
state_ = nullptr;
if(s.state_ == nullptr) {
if (s.state_ == nullptr) {
return;
}
......@@ -79,7 +78,7 @@ Status::CopyFrom(const Status &s) {
memcpy(&length, s.state_ + CODE_WIDTH, sizeof(length));
int buff_len = length + sizeof(length) + CODE_WIDTH;
state_ = new char[buff_len];
memcpy((void*)state_, (void*)s.state_, buff_len);
memcpy((void *) state_, (void *) s.state_, buff_len);
}
void
......@@ -98,12 +97,13 @@ Status::message() const {
std::string msg;
uint32_t length = 0;
memcpy(&length, state_ + CODE_WIDTH, sizeof(length));
if(length > 0) {
if (length > 0) {
msg.append(state_ + sizeof(length) + CODE_WIDTH, length);
}
return msg;
}
}
} // namespace milvus
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册