提交 0a8edff1 编写于 作者: K kun yu

merge brach-0.4.0


Former-commit-id: c752d75ebd02e18c7e5fc1fb21edb32d1c381aff
上级 45bb9d67
......@@ -185,153 +185,153 @@ namespace {
void
ClientTest::Test(const std::string& address, const std::string& port) {
// std::shared_ptr<Connection> conn = Connection::Create();
//
// {//connect server
// ConnectParam param = {address, port};
// Status stat = conn->Connect(param);
// std::cout << "Connect function call status: " << stat.ToString() << std::endl;
// }
//
// {//server version
// std::string version = conn->ServerVersion();
// std::cout << "Server version: " << version << std::endl;
// }
//
// {//sdk version
// std::string version = conn->ClientVersion();
// std::cout << "SDK version: " << version << std::endl;
// }
//
// {
// std::vector<std::string> tables;
// Status stat = conn->ShowTables(tables);
// std::cout << "ShowTables function call status: " << stat.ToString() << std::endl;
// std::cout << "All tables: " << std::endl;
// for(auto& table : tables) {
// int64_t row_count = 0;
// stat = conn->GetTableRowCount(table, row_count);
// std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl;
// }
// }
//
// {//create table
// TableSchema tb_schema = BuildTableSchema();
// Status stat = conn->CreateTable(tb_schema);
// std::cout << "CreateTable function call status: " << stat.ToString() << std::endl;
// PrintTableSchema(tb_schema);
//
// bool has_table = conn->HasTable(tb_schema.table_name);
// if(has_table) {
// std::cout << "Table is created" << std::endl;
// }
// }
//
// {//describe table
// TableSchema tb_schema;
// Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
// std::cout << "DescribeTable function call status: " << stat.ToString() << std::endl;
// PrintTableSchema(tb_schema);
// }
//
// Connection::Destroy(conn);
std::shared_ptr<Connection> conn = Connection::Create();
pid_t pid;
for (int i = 0; i < 5; ++i) {
pid = fork();
if (pid == 0 || pid == -1) {
break;
}
{//connect server
ConnectParam param = {address, port};
Status stat = conn->Connect(param);
std::cout << "Connect function call status: " << stat.ToString() << std::endl;
}
if (pid == -1) {
std::cout << "fail to fork!\n";
exit(1);
} else if (pid == 0) {
std::shared_ptr<Connection> conn = Connection::Create();
{//connect server
ConnectParam param = {address, port};
Status stat = conn->Connect(param);
std::cout << "Connect function call status: " << stat.ToString() << std::endl;
}
{//server version
std::string version = conn->ServerVersion();
std::cout << "Server version: " << version << std::endl;
}
Connection::Destroy(conn);
exit(0);
} else {
std::shared_ptr<Connection> conn = Connection::Create();
{//connect server
ConnectParam param = {address, port};
Status stat = conn->Connect(param);
std::cout << "Connect function call status: " << stat.ToString() << std::endl;
{//server version
std::string version = conn->ServerVersion();
std::cout << "Server version: " << version << std::endl;
}
{//sdk version
std::string version = conn->ClientVersion();
std::cout << "SDK version: " << version << std::endl;
}
{
std::vector<std::string> tables;
Status stat = conn->ShowTables(tables);
std::cout << "ShowTables function call status: " << stat.ToString() << std::endl;
std::cout << "All tables: " << std::endl;
for(auto& table : tables) {
int64_t row_count = 0;
stat = conn->GetTableRowCount(table, row_count);
std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl;
}
}
{//server version
std::string version = conn->ServerVersion();
std::cout << "Server version: " << version << std::endl;
{//create table
TableSchema tb_schema = BuildTableSchema();
Status stat = conn->CreateTable(tb_schema);
std::cout << "CreateTable function call status: " << stat.ToString() << std::endl;
PrintTableSchema(tb_schema);
bool has_table = conn->HasTable(tb_schema.table_name);
if(has_table) {
std::cout << "Table is created" << std::endl;
}
Connection::Destroy(conn);
std::cout << "in main process\n";
exit(0);
}
// std::vector<std::pair<int64_t, RowRecord>> search_record_array;
// {//insert vectors
// for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
// std::vector<RowRecord> record_array;
// int64_t begin_index = i * BATCH_ROW_COUNT;
// BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
// std::vector<int64_t> record_ids;
//
// auto start = std::chrono::high_resolution_clock::now();
//
// Status stat = conn->InsertVector(TABLE_NAME, record_array, record_ids);
// auto finish = std::chrono::high_resolution_clock::now();
// std::cout << "InsertVector cost: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
//
//
// std::cout << "InsertVector function call status: " << stat.ToString() << std::endl;
// std::cout << "Returned id array count: " << record_ids.size() << std::endl;
{//describe table
TableSchema tb_schema;
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
std::cout << "DescribeTable function call status: " << stat.ToString() << std::endl;
PrintTableSchema(tb_schema);
}
//
// if(search_record_array.size() < NQ) {
// search_record_array.push_back(
// std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
// }
// Connection::Destroy(conn);
// pid_t pid;
// for (int i = 0; i < 5; ++i) {
// pid = fork();
// if (pid == 0 || pid == -1) {
// break;
// }
// }
// if (pid == -1) {
// std::cout << "fail to fork!\n";
// exit(1);
// } else if (pid == 0) {
// std::shared_ptr<Connection> conn = Connection::Create();
//
// {//search vectors without index
// Sleep(2);
// DoSearch(conn, search_record_array, "Search without index");
// }
//
// {//wait unit build index finish
//// std::cout << "Wait until build all index done" << std::endl;
//// Status stat = conn->BuildIndex(TABLE_NAME);
//// std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl;
// }
// {//connect server
// ConnectParam param = {address, port};
// Status stat = conn->Connect(param);
// std::cout << "Connect function call status: " << stat.ToString() << std::endl;
// }
//
// {//search vectors after build index finish
// DoSearch(conn, search_record_array, "Search after build index finish");
// }
// {//server version
// std::string version = conn->ServerVersion();
// std::cout << "Server version: " << version << std::endl;
// }
// Connection::Destroy(conn);
// exit(0);
// } else {
// std::shared_ptr<Connection> conn = Connection::Create();
//
// {//delete table
// Status stat = conn->DropTable(TABLE_NAME);
// std::cout << "DeleteTable function call status: " << stat.ToString() << std::endl;
// }
// {//connect server
// ConnectParam param = {address, port};
// Status stat = conn->Connect(param);
// std::cout << "Connect function call status: " << stat.ToString() << std::endl;
// }
//
// {//server status
// std::string status = conn->ServerStatus();
// std::cout << "Server status before disconnect: " << status << std::endl;
// }
// Connection::Destroy(conn);
//// conn->Disconnect();
// {//server status
// std::string status = conn->ServerStatus();
// std::cout << "Server status after disconnect: " << status << std::endl;
// {//server version
// std::string version = conn->ServerVersion();
// std::cout << "Server version: " << version << std::endl;
// }
// Connection::Destroy(conn);
// std::cout << "in main process\n";
// exit(0);
// }
std::vector<std::pair<int64_t, RowRecord>> search_record_array;
{//insert vectors
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
std::vector<RowRecord> record_array;
int64_t begin_index = i * BATCH_ROW_COUNT;
BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
std::vector<int64_t> record_ids;
auto start = std::chrono::high_resolution_clock::now();
Status stat = conn->InsertVector(TABLE_NAME, record_array, record_ids);
auto finish = std::chrono::high_resolution_clock::now();
std::cout << "InsertVector cost: " << std::chrono::duration_cast<std::chrono::duration<double>>(finish - start).count() << "s\n";
std::cout << "InsertVector function call status: " << stat.ToString() << std::endl;
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
if(search_record_array.size() < NQ) {
search_record_array.push_back(
std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET]));
}
}
}
{//search vectors without index
Sleep(2);
DoSearch(conn, search_record_array, "Search without index");
}
{//wait unit build index finish
// std::cout << "Wait until build all index done" << std::endl;
// Status stat = conn->BuildIndex(TABLE_NAME);
// std::cout << "BuildIndex function call status: " << stat.ToString() << std::endl;
}
{//search vectors after build index finish
DoSearch(conn, search_record_array, "Search after build index finish");
}
{//delete table
Status stat = conn->DropTable(TABLE_NAME);
std::cout << "DeleteTable function call status: " << stat.ToString() << std::endl;
}
{//server status
std::string status = conn->ServerStatus();
std::cout << "Server status before disconnect: " << status << std::endl;
}
Connection::Destroy(conn);
// conn->Disconnect();
{//server status
std::string status = conn->ServerStatus();
std::cout << "Server status after disconnect: " << status << std::endl;
}
}
\ No newline at end of file
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <string>
#include <chrono>
namespace zilliz {
namespace knowhere {
class TimeRecorder {
using stdclock = std::chrono::high_resolution_clock;
public:
TimeRecorder(const std::string &header,
int64_t log_level = 0);
~TimeRecorder();//trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5
double RecordSection(const std::string &msg);
double ElapseFromBegin(const std::string &msg);
static std::string GetTimeSpanStr(double span);
private:
void PrintTimeRecord(const std::string &msg, double span);
private:
std::string header_;
stdclock::time_point start_;
stdclock::time_point last_;
int64_t log_level_;
};
}
}
......@@ -19,6 +19,12 @@ class GPUIVF : public IVF {
void SetGpuDevice(const int &gpu_id);
protected:
void search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg) override;
BinarySet SerializeImpl() override;
void LoadImpl(const BinarySet &index_binary) override;
......
......@@ -29,6 +29,7 @@ class BasicIndex {
std::shared_ptr<faiss::Index> index_ = nullptr;
};
using Graph = std::vector<std::vector<int64_t>>;
class IVF : public VectorIndex, public BasicIndex {
public:
......@@ -37,17 +38,24 @@ class IVF : public VectorIndex, public BasicIndex {
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
void set_index_model(IndexModelPtr model) override;
void Add(const DatasetPtr &dataset, const Config &config) override;
void AddWithoutIds(const DatasetPtr &dataset, const Config &config);
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
void GenGraph(const int64_t &k, Graph &graph, const DatasetPtr &dataset, const Config &config);
BinarySet Serialize() override;
void Load(const BinarySet &index_binary) override;
int64_t Count() override;
int64_t Dimension() override;
//DatasetPtr Search_twice(const DatasetPtr &dataset, const Config &config, float* xb);
protected:
virtual std::shared_ptr<faiss::IVFSearchParameters> GenParams(const Config &config);
virtual void search_impl(int64_t n,
const float *data,
int64_t k,
float *distances,
int64_t *labels,
const Config &cfg);
protected:
std::mutex mutex_;
};
......
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <mutex>
namespace zilliz {
namespace knowhere {
namespace algo {
using node_t = int64_t;
// TODO: search use simple neighbor
struct Neighbor {
node_t id; // offset of node in origin data
float distance;
bool has_explored;
Neighbor() = default;
explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) {}
explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) {}
inline bool operator<(const Neighbor &other) const {
return distance < other.distance;
}
};
//struct SimpleNeighbor {
// node_t id; // offset of node in origin data
// float distance;
//
// SimpleNeighbor() = default;
// explicit SimpleNeighbor(node_t id, float distance) : id{id}, distance{distance}{}
//
// inline bool operator<(const Neighbor &other) const {
// return distance < other.distance;
// }
//};
typedef std::lock_guard<std::mutex> LockGuard;
}
}
}
\ No newline at end of file
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include <cstddef>
#include <vector>
#include <mutex>
#include <boost/dynamic_bitset.hpp>
#include "neighbor.h"
namespace zilliz {
namespace knowhere {
namespace algo {
using node_t = int64_t;
enum class MetricType {
METRIC_INNER_PRODUCT = 0,
METRIC_L2 = 1,
};
struct BuildParams {
size_t search_length;
size_t out_degree;
size_t candidate_pool_size;
};
struct SearchParams {
size_t search_length;
};
using Graph = std::vector<std::vector<node_t>>;
class NsgIndex {
public:
size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
MetricType metric_type; // L2 | IP
float *ori_data_;
long *ids_; // TODO: support different type
Graph nsg; // final graph
Graph knng; // reset after build
node_t navigation_point; // offset of node in origin data
bool is_trained = false;
/*
* build and search parameter
*/
size_t search_length;
size_t candidate_pool_size; // search deepth in fullset
size_t out_degree;
public:
explicit NsgIndex(const size_t &dimension,
const size_t &n,
MetricType metric = MetricType::METRIC_L2);
NsgIndex() = default;
virtual ~NsgIndex();
void SetKnnGraph(Graph &knng);
virtual void Build_with_ids(size_t nb,
const float *data,
const long *ids,
const BuildParams &parameters);
void Search(const float *query,
const unsigned &nq,
const unsigned &dim,
const unsigned &k,
float *dist,
long *ids,
SearchParams &params);
// Not support yet.
//virtual void Add() = 0;
//virtual void Add_with_ids() = 0;
//virtual void Delete() = 0;
//virtual void Delete_with_ids() = 0;
//virtual void Rebuild(size_t nb,
// const float *data,
// const long *ids,
// const Parameters &parameters) = 0;
//virtual void Build(size_t nb,
// const float *data,
// const BuildParam &parameters);
protected:
virtual void InitNavigationPoint();
// link specify
void GetNeighbors(const float *query,
std::vector<Neighbor> &resset,
std::vector<Neighbor> &fullset,
boost::dynamic_bitset<> &has_calculated_dist);
// FindUnconnectedNode
void GetNeighbors(const float *query,
std::vector<Neighbor> &resset,
std::vector<Neighbor> &fullset);
// search and navigation-point
void GetNeighbors(const float *query,
std::vector<Neighbor> &resset,
Graph &graph,
SearchParams *param = nullptr);
void Link();
void SyncPrune(size_t q,
std::vector<Neighbor> &pool,
boost::dynamic_bitset<> &has_calculated,
float *cut_graph_dist
);
void SelectEdge(unsigned &cursor,
std::vector<Neighbor> &sort_pool,
std::vector<Neighbor> &result,
bool limit = false);
void InterInsert(unsigned n, std::vector<std::mutex> &mutex_vec, float *dist);
void CheckConnectivity();
void DFS(size_t root, boost::dynamic_bitset<> &flags, int64_t &count);
void FindUnconnectedNode(boost::dynamic_bitset<> &flags, int64_t &root);
private:
void GetKnnGraphFromFile();
};
}
}
}
%module nsg
%{
#define SWIG_FILE_WITH_INIT
#include <numpy/arrayobject.h>
/* Include the header in the wrapper code */
#include "nsg.h"
%}
/* Parse the header file */
%include "index.h"
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "nsg.h"
#include "knowhere/index/vector_index/ivf.h"
namespace zilliz {
namespace knowhere {
namespace algo {
extern void write_index(NsgIndex* index, MemoryIOWriter& writer);
extern NsgIndex* read_index(MemoryIOReader& reader);
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "knowhere/index/vector_index/vector_index.h"
namespace zilliz {
namespace knowhere {
namespace algo {
class NsgIndex;
}
class NSG : public VectorIndex {
public:
explicit NSG(const int64_t& gpu_num):gpu_(gpu_num){}
NSG() = default;
IndexModelPtr Train(const DatasetPtr &dataset, const Config &config) override;
DatasetPtr Search(const DatasetPtr &dataset, const Config &config) override;
void Add(const DatasetPtr &dataset, const Config &config) override;
BinarySet Serialize() override;
void Load(const BinarySet &index_binary) override;
int64_t Count() override;
int64_t Dimension() override;
private:
std::shared_ptr<algo::NsgIndex> index_;
int64_t gpu_;
};
using NSGIndexPtr = std::shared_ptr<NSG>();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册