提交 1f39173f 编写于 作者: K kun yu

branch-0.4.0


Former-commit-id: b3b115e418b0313824ae103967ac791f7c78dd15
上级 bdf8ab7a
......@@ -655,6 +655,8 @@ ServerError PingTask::OnExecute() {
result_ = MILVUS_VERSION;
} else if (cmd_ == "disconnect") {
//TODO stopservice
} else {
result_ = "OK";
}
return SERVER_SUCCESS;
......
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#include "inc/Socket/Common.h"
#include "AggregatorSettings.h"
#include <memory>
#include <vector>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
enum RemoteMachineStatus : uint8_t
{
Disconnected = 0,
Connecting,
Connected
};
struct RemoteMachine
{
RemoteMachine();
std::string m_address;
std::string m_port;
Socket::ConnectionID m_connectionID;
std::atomic<RemoteMachineStatus> m_status;
};
class AggregatorContext
{
public:
AggregatorContext(const std::string& p_filePath);
~AggregatorContext();
bool IsInitialized() const;
const std::vector<std::shared_ptr<RemoteMachine>>& GetRemoteServers() const;
const std::shared_ptr<AggregatorSettings>& GetSettings() const;
private:
std::vector<std::shared_ptr<RemoteMachine>> m_remoteServers;
std::shared_ptr<AggregatorSettings> m_settings;
bool m_initialized;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/Packet.h"
#include <memory>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
typedef std::shared_ptr<Socket::RemoteSearchResult> AggregatorResult;
class AggregatorExecutionContext
{
public:
AggregatorExecutionContext(std::size_t p_totalServerNumber,
Socket::PacketHeader p_requestHeader);
~AggregatorExecutionContext();
std::size_t GetServerNumber() const;
AggregatorResult& GetResult(std::size_t p_num);
const Socket::PacketHeader& GetRequestHeader() const;
bool IsCompletedAfterFinsh(std::uint32_t p_finishedCount);
private:
std::atomic<std::uint32_t> m_unfinishedCount;
std::vector<AggregatorResult> m_results;
Socket::PacketHeader m_requestHeader;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#include "AggregatorContext.h"
#include "AggregatorExecutionContext.h"
#include "inc/Socket/Server.h"
#include "inc/Socket/Client.h"
#include "inc/Socket/ResourceManager.h"
#include <boost/asio.hpp>
#include <memory>
#include <vector>
#include <thread>
#include <condition_variable>
namespace SPTAG
{
namespace Aggregator
{
class AggregatorService
{
public:
AggregatorService();
~AggregatorService();
bool Initialize();
void Run();
private:
void StartClient();
void StartListen();
void WaitForShutdown();
void ConnectToPendingServers();
void AddToPendingServers(std::shared_ptr<RemoteMachine> p_remoteServer);
void SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void AggregateResults(std::shared_ptr<AggregatorExecutionContext> p_exectionContext);
std::shared_ptr<AggregatorContext> GetContext();
private:
typedef std::function<void(Socket::RemoteSearchResult)> AggregatorCallback;
std::shared_ptr<AggregatorContext> m_aggregatorContext;
std::shared_ptr<Socket::Server> m_socketServer;
std::shared_ptr<Socket::Client> m_socketClient;
bool m_initalized;
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
boost::asio::io_context m_ioContext;
boost::asio::signal_set m_shutdownSignals;
std::vector<std::shared_ptr<RemoteMachine>> m_pendingConnectServers;
std::mutex m_pendingConnectServersMutex;
boost::asio::deadline_timer m_pendingConnectServersTimer;
Socket::ResourceManager<AggregatorCallback> m_aggregatorCallbackManager;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#include "../Core/Common.h"
#include <string>
namespace SPTAG
{
namespace Aggregator
{
struct AggregatorSettings
{
AggregatorSettings();
std::string m_listenAddr;
std::string m_listenPort;
std::uint32_t m_searchTimeout;
SizeType m_threadNum;
SizeType m_socketThreadNum;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_CLIENTWRAPPER_H_
#define _SPTAG_CLIENT_CLIENTWRAPPER_H_
#include "inc/Socket/Client.h"
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/ResourceManager.h"
#include "Options.h"
#include <string>
#include <vector>
#include <memory>
#include <atomic>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
namespace SPTAG
{
namespace Client
{
class ClientWrapper
{
public:
typedef std::function<void(Socket::RemoteSearchResult)> Callback;
ClientWrapper(const ClientOptions& p_options);
~ClientWrapper();
void SendQueryAsync(const Socket::RemoteQuery& p_query,
Callback p_callback,
const ClientOptions& p_options);
void WaitAllFinished();
bool IsAvailable() const;
private:
typedef std::pair<Socket::ConnectionID, Socket::ConnectionID> ConnectionPair;
Socket::PacketHandlerMapPtr GetHandlerMap();
void DecreaseUnfnishedJobCount();
const ConnectionPair& GetConnection();
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void HandleDeadConnection(Socket::ConnectionID p_cid);
private:
ClientOptions m_options;
std::unique_ptr<Socket::Client> m_client;
std::atomic<std::uint32_t> m_unfinishedJobCount;
std::atomic_bool m_isWaitingFinish;
std::condition_variable m_waitingQueue;
std::mutex m_waitingMutex;
std::vector<ConnectionPair> m_connections;
std::atomic<std::uint32_t> m_spinCountOfConnection;
Socket::ResourceManager<Callback> m_callbackManager;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_OPTIONS_H_
#define _SPTAG_CLIENT_OPTIONS_H_
#include "inc/Helper/ArgumentsParser.h"
#include <string>
#include <vector>
#include <memory>
namespace SPTAG
{
namespace Client
{
class ClientOptions : public Helper::ArgumentsParser
{
public:
ClientOptions();
virtual ~ClientOptions();
std::string m_serverAddr;
std::string m_serverPort;
// in milliseconds.
std::uint32_t m_searchTimeout;
std::uint32_t m_threadNum;
std::uint32_t m_socketThreadNum;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_BKT_INDEX_H_
#define _SPTAG_BKT_INDEX_H_
#include "../Common.h"
#include "../VectorIndex.h"
#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/BKTree.h"
#include "inc/Helper/SimpleIniReader.h"
#include "inc/Helper/StringConvert.h"
#include <functional>
#include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG
{
namespace Helper
{
class IniReader;
}
namespace BKT
{
template<typename T>
class Index : public VectorIndex
{
private:
// data points
COMMON::Dataset<T> m_pSamples;
// BKT structures.
COMMON::BKTree m_pTrees;
// Graph structure
COMMON::RelativeNeighborhoodGraph m_pGraph;
std::string m_sBKTFilename;
std::string m_sGraphFilename;
std::string m_sDataPointsFilename;
std::mutex m_dataLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots;
public:
Index()
{
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \
#include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; }
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; }
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
};
} // namespace BKT
} // namespace SPTAG
#endif // _SPTAG_BKT_INDEX_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineBKTParameter
// DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber")
DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TpTreeNumber")
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit")
DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
DefineBKTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
DefineBKTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CORE_COMMONDEFS_H_
#define _SPTAG_CORE_COMMONDEFS_H_
#include <cstdint>
#include <type_traits>
#include <memory>
#include <string>
#include <limits>
#include <vector>
#include <cmath>
#ifndef _MSC_VER
#include <sys/stat.h>
#include <sys/types.h>
#define FolderSep '/'
#define mkdir(a) mkdir(a, ACCESSPERMS)
inline bool direxists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR);
}
inline bool fileexists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0;
}
template <class T>
inline T min(T a, T b) {
return a < b ? a : b;
}
template <class T>
inline T max(T a, T b) {
return a > b ? a : b;
}
#ifndef _rotl
#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n))))
#endif
#else
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <Psapi.h>
#define FolderSep '\\'
#define mkdir(a) CreateDirectory(a, NULL)
inline bool direxists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY);
}
inline bool fileexists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0;
}
#endif
namespace SPTAG
{
typedef std::uint32_t SizeType;
const float MinDist = (std::numeric_limits<float>::min)();
const float MaxDist = (std::numeric_limits<float>::max)();
const float Epsilon = 0.000000001f;
class MyException : public std::exception
{
private:
std::string Exp;
public:
MyException(std::string e) { Exp = e; }
#ifdef _MSC_VER
const char* what() const { return Exp.c_str(); }
#else
const char* what() const noexcept { return Exp.c_str(); }
#endif
};
// Type of number index.
typedef std::int32_t IndexType;
static_assert(std::is_integral<IndexType>::value, "IndexType must be integral type.");
enum class ErrorCode : std::uint16_t
{
#define DefineErrorCode(Name, Value) Name = Value,
#include "DefinitionList.h"
#undef DefineErrorCode
Undefined
};
static_assert(static_cast<std::uint16_t>(ErrorCode::Undefined) != 0, "Empty ErrorCode!");
enum class DistCalcMethod : std::uint8_t
{
#define DefineDistCalcMethod(Name) Name,
#include "DefinitionList.h"
#undef DefineDistCalcMethod
Undefined
};
static_assert(static_cast<std::uint8_t>(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!");
enum class VectorValueType : std::uint8_t
{
#define DefineVectorValueType(Name, Type) Name,
#include "DefinitionList.h"
#undef DefineVectorValueType
Undefined
};
static_assert(static_cast<std::uint8_t>(VectorValueType::Undefined) != 0, "Empty VectorValueType!");
enum class IndexAlgoType : std::uint8_t
{
#define DefineIndexAlgo(Name) Name,
#include "DefinitionList.h"
#undef DefineIndexAlgo
Undefined
};
static_assert(static_cast<std::uint8_t>(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!");
template<typename T>
constexpr VectorValueType GetEnumValueType()
{
return VectorValueType::Undefined;
}
#define DefineVectorValueType(Name, Type) \
template<> \
constexpr VectorValueType GetEnumValueType<Type>() \
{ \
return VectorValueType::Name; \
} \
#include "DefinitionList.h"
#undef DefineVectorValueType
inline std::size_t GetValueTypeSize(VectorValueType p_valueType)
{
switch (p_valueType)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return sizeof(Type); \
#include "DefinitionList.h"
#undef DefineVectorValueType
default:
break;
}
return 0;
}
} // namespace SPTAG
#endif // _SPTAG_CORE_COMMONDEFS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_COMMONUTILS_H_
#define _SPTAG_COMMON_COMMONUTILS_H_
#include "../Common.h"
#include <unordered_map>
#include <fstream>
#include <iostream>
#include <exception>
#include <algorithm>
#include <time.h>
#include <omp.h>
#include <string.h>
#define PREFETCH
#ifndef _MSC_VER
#include <stdio.h>
#include <unistd.h>
#include <sys/resource.h>
#include <cstring>
#define InterlockedCompareExchange(a,b,c) __sync_val_compare_and_swap(a, c, b)
#define Sleep(a) usleep(a * 1000)
#define strtok_s(a, b, c) strtok_r(a, b, c)
#endif
namespace SPTAG
{
namespace COMMON
{
class Utils {
public:
static int rand_int(int high = RAND_MAX, int low = 0) // Generates a random int value.
{
return low + (int)(float(high - low)*(std::rand() / (RAND_MAX + 1.0)));
}
static inline float atomic_float_add(volatile float* ptr, const float operand)
{
union {
volatile long iOld;
float fOld;
};
union {
long iNew;
float fNew;
};
while (true) {
iOld = *(volatile long *)ptr;
fNew = fOld + operand;
if (InterlockedCompareExchange((long *)ptr, iNew, iOld) == iOld) {
return fNew;
}
}
}
static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, int& NumDim) {
char* current;
char* context = NULL;
int i = 0;
double sum = 0;
arr.clear();
current = strtok_s(cstr, sep, &context);
while (current != NULL && (i < NumDim || NumDim < 0)) {
try {
float val = (float)atof(current);
arr.push_back(val);
}
catch (std::exception e) {
std::cout << "Exception:" << e.what() << std::endl;
return -2;
}
sum += arr[i] * arr[i];
current = strtok_s(NULL, sep, &context);
i++;
}
if (NumDim < 0) NumDim = i;
if (i < NumDim) return -2;
return std::sqrt(sum);
}
template <typename T>
static void Normalize(T* arr, int col, int base) {
double vecLen = 0;
for (int j = 0; j < col; j++) {
double val = arr[j];
vecLen += val * val;
}
vecLen = std::sqrt(vecLen);
if (vecLen < 1e-6) {
T val = (T)(1.0 / std::sqrt((double)col) * base);
for (int j = 0; j < col; j++) arr[j] = val;
}
else {
for (int j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base);
}
}
static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, int& D, int base, DistCalcMethod distCalcMethod) {
size_t index;
double vecLen;
if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) {
std::cout << "Parse vector error: " + currentLine << std::endl;
//throw MyException("Error in parsing data " + currentLine);
return -1;
}
if (distCalcMethod == DistCalcMethod::Cosine) {
Normalize(arr.data(), D, base);
}
return index;
}
template <typename T>
static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, int& NumQuery, int& NumDim, DistCalcMethod distCalcMethod, int base) {
std::string currentLine;
std::vector<float> arr;
int i = 0;
size_t index;
while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) {
std::getline(inStream, currentLine);
if (currentLine.length() <= 1 || (index = ProcessLine(currentLine, arr, NumDim, base, distCalcMethod)) < 0) {
continue;
}
qString.push_back(currentLine.substr(0, index));
if (Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0));
for (int j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j];
i++;
}
NumQuery = i;
std::cout << "Load data: (" << NumQuery << ", " << NumDim << ")" << std::endl;
}
template<typename T>
static inline int GetBase() {
if (GetEnumValueType<T>() != VectorValueType::Float) {
return (int)(std::numeric_limits<T>::max)();
}
return 1;
}
static inline void AddNeighbor(int idx, float dist, int *neighbors, float *dists, int size)
{
size--;
if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size]))
{
int nb;
for (nb = 0; nb <= size && neighbors[nb] != idx; nb++);
if (nb > size)
{
nb = size;
while (nb > 0 && (dist < dists[nb - 1] || (dist == dists[nb - 1] && idx < neighbors[nb - 1])))
{
dists[nb] = dists[nb - 1];
neighbors[nb] = neighbors[nb - 1];
nb--;
}
dists[nb] = dist;
neighbors[nb] = idx;
}
}
}
};
}
}
#endif // _SPTAG_COMMON_COMMONUTILS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATAUTILS_H_
#define _SPTAG_COMMON_DATAUTILS_H_
#include <sys/stat.h>
#include <atomic>
#include "CommonUtils.h"
#include "../../Helper/CommonHelper.h"
namespace SPTAG
{
namespace COMMON
{
const int bufsize = 1024 * 1024 * 1024;
class DataUtils {
public:
template <typename T>
static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
std::ifstream inputStream(filename);
if (!inputStream.is_open()) {
std::cerr << "unable to open file " + filename << std::endl;
throw MyException("unable to open file " + filename);
exit(1);
}
std::ofstream outputStream, metaStream_out, metaStream_index;
outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
throw MyException("unable to open output files");
exit(1);
}
std::vector<float> arr;
std::vector<T> sample;
int base = 1;
if (distCalcMethod == DistCalcMethod::Cosine) {
base = Utils::GetBase<T>();
}
std::uint64_t writepos = 0;
int sampleSize = 0;
std::uint64_t totalread = 0;
std::streamoff startpos = id * blocksize;
#ifndef _MSC_VER
int enter_size = 1;
#else
int enter_size = 1;
#endif
std::string currentLine;
size_t index;
inputStream.seekg(startpos, std::ifstream::beg);
if (id != 0) {
std::getline(inputStream, currentLine);
totalread += currentLine.length() + enter_size;
}
std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
while (!inputStream.eof() && totalread <= blocksize) {
std::getline(inputStream, currentLine);
if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
totalread += currentLine.length() + enter_size;
continue;
}
sample.resize(D);
for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
outputStream.write((char *)(sample.data()), sizeof(T)*D);
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_out.write(currentLine.c_str(), index);
writepos += index;
sampleSize += 1;
totalread += currentLine.length() + enter_size;
}
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_index.write((char *)&sampleSize, sizeof(int));
inputStream.close();
outputStream.close();
metaStream_out.close();
metaStream_index.close();
numSamples.fetch_add(sampleSize);
std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
}
static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int D) {
std::ifstream inputStream;
std::ofstream outputStream;
char * buf = new char[bufsize];
std::uint64_t * offsets;
int partSamples;
int metaSamples = 0;
std::uint64_t lastoff = 0;
outputStream.open(outfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
outputStream.write((char *)&D, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
outputStream.open(outmetafile, std::ofstream::binary);
for (int i = 0; i < threadbase; i++) {
std::string file = outmetafile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
delete[] buf;
outputStream.open(outmetaindexfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outmetaindexfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
inputStream.read((char *)&partSamples, sizeof(int));
offsets = new std::uint64_t[partSamples + 1];
inputStream.seekg(0, inputStream.beg);
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
inputStream.close();
remove(file.c_str());
for (int j = 0; j < partSamples + 1; j++)
offsets[j] += lastoff;
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
lastoff = offsets[partSamples];
metaSamples += partSamples;
delete[] offsets;
}
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
}
static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
std::ifstream inputStream1, inputStream2;
std::ofstream outputStream;
char * buf = new char[bufsize];
int R1, R2, C1, C2;
#define MergeVector(inputStream, vectorFile, R, C) \
inputStream.open(vectorFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
return false; \
} \
inputStream.read((char *)&(R), sizeof(int)); \
inputStream.read((char *)&(C), sizeof(int)); \
MergeVector(inputStream1, p_vectorfile1, R1, C1)
MergeVector(inputStream2, p_vectorfile2, R2, C2)
#undef MergeVector
if (C1 != C2) {
inputStream1.close(); inputStream2.close();
std::cout << "Vector dimensions are not the same!" << std::endl;
return false;
}
R1 += R2;
outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
outputStream.write((char *)&C1, sizeof(int));
while (!inputStream1.eof()) {
inputStream1.read(buf, bufsize);
outputStream.write(buf, inputStream1.gcount());
}
while (!inputStream2.eof()) {
inputStream2.read(buf, bufsize);
outputStream.write(buf, inputStream2.gcount());
}
inputStream1.close(); inputStream2.close();
outputStream.close();
if (p_metafile1 != "" && p_metafile2 != "") {
outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary);
#define MergeMeta(inputStream, metaFile) \
inputStream.open(metaFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \
return false; \
} \
while (!inputStream.eof()) { \
inputStream.read(buf, bufsize); \
outputStream.write(buf, inputStream.gcount()); \
} \
inputStream.close(); \
MergeMeta(inputStream1, p_metafile1)
MergeMeta(inputStream2, p_metafile2)
#undef MergeMeta
outputStream.close();
delete[] buf;
std::uint64_t * offsets;
int partSamples;
std::uint64_t lastoff = 0;
outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
#define MergeMetaIndex(inputStream, metaIndexFile) \
inputStream.open(metaIndexFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
return false; \
} \
inputStream.read((char *)&partSamples, sizeof(int)); \
offsets = new std::uint64_t[partSamples + 1]; \
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
inputStream.close(); \
for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \
lastoff = offsets[partSamples]; \
delete[] offsets; \
MergeMetaIndex(inputStream1, p_metaindexfile1)
MergeMetaIndex(inputStream2, p_metaindexfile2)
#undef MergeMetaIndex
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str());
rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str());
}
rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str());
std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
return true;
}
template <typename T>
static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
int threadnum, DistCalcMethod distCalcMethod) {
omp_set_num_threads(threadnum);
std::atomic_int numSamples = { 0 };
int D = -1;
int threadbase = 0;
std::vector<std::string> inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
for (std::string inputFileName : inputFileNames)
{
#ifndef _MSC_VER
struct stat stat_buf;
stat(inputFileName.c_str(), &stat_buf);
#else
struct _stat64 stat_buf;
int res = _stat64(inputFileName.c_str(), &stat_buf);
#endif
std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
#pragma omp parallel for
for (int i = 0; i < threadnum; i++) {
ProcessTSVData<T>(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
}
threadbase += threadnum;
}
MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
}
};
}
}
#endif // _SPTAG_COMMON_DATAUTILS_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATASET_H_
#define _SPTAG_COMMON_DATASET_H_
#include <fstream>
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
#include <malloc.h>
#else
#include <mm_malloc.h>
#endif // defined(__GNUC__)
#define ALIGN 32
#define aligned_malloc(a, b) _mm_malloc(a, b)
#define aligned_free(a) _mm_free(a)
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// structure to save Data and Graph
template <typename T>
class Dataset
{
private:
int rows;
int cols;
bool ownData = false;
T* data = nullptr;
std::vector<T> dataIncremental;
public:
Dataset(): rows(0), cols(1) {}
Dataset(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
Initialize(rows_, cols_, data_, transferOnwership_);
}
~Dataset()
{
if (ownData) aligned_free(data);
}
void Initialize(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
rows = rows_;
cols = cols_;
data = data_;
if (data_ == nullptr || !transferOnwership_)
{
ownData = true;
data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
if (data_ != nullptr) memcpy(data, data_, rows * cols * sizeof(T));
else std::memset(data, -1, rows * cols * sizeof(T));
}
}
void SetR(int R_)
{
if (R_ >= rows)
dataIncremental.resize((R_ - rows) * cols);
else
{
rows = R_;
dataIncremental.clear();
}
}
inline int R() const { return (int)(rows + dataIncremental.size() / cols); }
inline int C() const { return cols; }
T* operator[](int index)
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
const T* operator[](int index) const
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
void AddBatch(const T* pData, int num)
{
dataIncremental.insert(dataIncremental.end(), pData, pData + num*cols);
}
void AddBatch(int num)
{
dataIncremental.insert(dataIncremental.end(), (size_t)num*cols, T(-1));
}
bool Save(std::string sDataPointsFileName)
{
std::cout << "Save Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int CR = R();
fwrite(&CR, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
T* ptr = data;
int toWrite = rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
ptr = dataIncremental.data();
toWrite = CR - rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
fclose(fp);
std::cout << "Save Data (" << CR << ", " << cols << ") Finish!" << std::endl;
return true;
}
bool Save(void **pDataPointsMemFile, int64_t &len)
{
size_t size = sizeof(int) + sizeof(int) + sizeof(T) * R() *cols;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
int CR = R();
auto header = (int*)mem;
header[0] = CR;
header[1] = cols;
auto body = &mem[8];
memcpy(body, data, sizeof(T) * cols * rows);
body += sizeof(T) * cols * rows;
memcpy(body, dataIncremental.data(), sizeof(T) * cols * (CR - rows));
body += sizeof(T) * cols * (CR - rows);
*pDataPointsMemFile = mem;
len = size;
return true;
}
bool Load(std::string sDataPointsFileName)
{
std::cout << "Load Data From " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
if (fp == NULL) return false;
int R, C;
fread(&R, sizeof(int), 1, fp);
fread(&C, sizeof(int), 1, fp);
Initialize(R, C);
T* ptr = data;
while (R > 0) {
size_t read = fread(ptr, sizeof(T) * C, R, fp);
ptr += read * C;
R -= (int)read;
}
fclose(fp);
std::cout << "Load Data (" << rows << ", " << cols << ") Finish!" << std::endl;
return true;
}
// Functions for loading models from memory mapped files
bool Load(char* pDataPointsMemFile)
{
int R, C;
R = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
C = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
Initialize(R, C, (T*)pDataPointsMemFile);
return true;
}
bool Refine(const std::vector<int>& indices, std::string sDataPointsFileName)
{
std::cout << "Save Refine Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int R = (int)(indices.size());
fwrite(&R, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
// write point one by one in case for cache miss
for (int i = 0; i < R; i++) {
if (indices[i] < rows)
fwrite(data + (size_t)indices[i] * cols, sizeof(T) * cols, 1, fp);
else
fwrite(dataIncremental.data() + (size_t)(indices[i] - rows) * cols, sizeof(T) * cols, 1, fp);
}
fclose(fp);
std::cout << "Save Refine Data (" << R << ", " << cols << ") Finish!" << std::endl;
return true;
}
};
}
}
#endif // _SPTAG_COMMON_DATASET_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#include <vector>
#include <mutex>
#include <memory>
namespace SPTAG
{
namespace COMMON
{
class FineGrainedLock {
public:
FineGrainedLock() {}
~FineGrainedLock() {
for (int i = 0; i < locks.size(); i++)
locks[i].reset();
locks.clear();
}
void resize(int n) {
int current = (int)locks.size();
if (current <= n) {
locks.resize(n);
for (int i = current; i < n; i++)
locks[i].reset(new std::mutex);
}
else {
for (int i = n; i < current; i++)
locks[i].reset();
locks.resize(n);
}
}
std::mutex& operator[](int idx) {
return *locks[idx];
}
const std::mutex& operator[](int idx) const {
return *locks[idx];
}
private:
std::vector<std::shared_ptr<std::mutex>> locks;
};
}
}
#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_
\ No newline at end of file
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_HEAP_H_
#define _SPTAG_COMMON_HEAP_H_
namespace SPTAG
{
namespace COMMON
{
// priority queue
template <typename T>
class Heap {
public:
Heap() : heap(nullptr), length(0), count(0) {}
Heap(int size) { Resize(size); }
void Resize(int size)
{
length = size;
heap.reset(new T[length + 1]); // heap uses 1-based indexing
count = 0;
lastlevel = int(pow(2.0, floor(log2(size))));
}
~Heap() {}
inline int size() { return count; }
inline bool empty() { return count == 0; }
inline void clear() { count = 0; }
inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; }
// Insert a new element in the heap.
void insert(T value)
{
/* If heap is full, then return without adding this element. */
int loc;
if (count == length) {
int maxi = lastlevel;
for (int i = lastlevel + 1; i <= length; i++)
if (heap[maxi] < heap[i]) maxi = i;
if (value > heap[maxi]) return;
loc = maxi;
}
else {
loc = ++(count); /* Remember 1-based indexing. */
}
/* Keep moving parents down until a place is found for this node. */
int par = (loc >> 1); /* Location of parent. */
while (par > 0 && value < heap[par]) {
heap[loc] = heap[par]; /* Move parent down to loc. */
loc = par;
par >>= 1;
}
/* Insert the element at the determined location. */
heap[loc] = value;
}
// Returns the node of minimum value from the heap (top of the heap).
bool pop(T& value)
{
if (count == 0) return false;
/* Switch first node with last. */
value = heap[1];
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return true; /* Return old last node. */
}
T& pop()
{
if (count == 0) return heap[0];
/* Switch first node with last. */
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return heap[count + 1]; /* Return old last node. */
}
private:
// Storage array for the heap.
// Type T must be comparable.
std::unique_ptr<T[]> heap;
int length;
int count; // Number of element in the heap
int lastlevel;
// Reorganizes the heap (a parent is smaller than its children) starting with a node.
void heapify()
{
int parent = 1, next = 2;
while (next < count) {
if (heap[next] > heap[next + 1]) next++;
if (heap[next] < heap[parent]) {
std::swap(heap[parent], heap[next]);
parent = next;
next <<= 1;
}
else break;
}
if (next == count && heap[next] < heap[parent]) std::swap(heap[parent], heap[next]);
}
};
}
}
#endif // _SPTAG_COMMON_HEAP_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_KDTREE_H_
#define _SPTAG_COMMON_KDTREE_H_
#include <iostream>
#include <vector>
#include <string>
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "QueryResultSet.h"
#include "WorkSpace.h"
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// node type for storing KDT
struct KDTNode
{
int left;
int right;
short split_dim;
float split_value;
};
class KDTree
{
public:
KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000) {}
KDTree(KDTree& other) : m_iTreeNumber(other.m_iTreeNumber),
m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit),
m_iSamples(other.m_iSamples) {}
~KDTree() {}
inline const KDTNode& operator[](int index) const { return m_pTreeRoots[index]; }
inline KDTNode& operator[](int index) { return m_pTreeRoots[index]; }
inline int size() const { return (int)m_pTreeRoots.size(); }
template <typename T>
void BuildTrees(VectorIndex* p_index, std::vector<int>* indices = nullptr)
{
std::vector<int> localindices;
if (indices == nullptr) {
localindices.resize(p_index->GetNumSamples());
for (int i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i;
}
else {
localindices.assign(indices->begin(), indices->end());
}
m_pTreeRoots.resize(m_iTreeNumber * localindices.size());
m_pTreeStart.resize(m_iTreeNumber, 0);
#pragma omp parallel for
for (int i = 0; i < m_iTreeNumber; i++)
{
Sleep(i * 100); std::srand(clock());
std::vector<int> pindices(localindices.begin(), localindices.end());
std::random_shuffle(pindices.begin(), pindices.end());
m_pTreeStart[i] = i * (int)pindices.size();
std::cout << "Start to build KDTree " << i + 1 << std::endl;
int iTreeSize = m_pTreeStart[i];
DivideTree<T>(p_index, pindices, 0, (int)pindices.size() - 1, m_pTreeStart[i], iTreeSize);
std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl;
}
}
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
{
int treeNodeSize = (int)m_pTreeRoots.size();
size_t size = sizeof(int) +
sizeof(int) * m_iTreeNumber +
sizeof(int) +
sizeof(KDTNode) * treeNodeSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iTreeNumber;
ptr += sizeof(int);
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
ptr += sizeof(int) * m_iTreeNumber;
*(int*)ptr = treeNodeSize;
ptr += sizeof(int);
memcpy(ptr, m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
*pKDTMemFile = mem;
len = size;
return true;
}
bool SaveTrees(std::string sTreeFileName) const
{
std::cout << "Save KDT to " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize = (int)m_pTreeRoots.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
fwrite(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
bool LoadTrees(char* pKDTMemFile)
{
m_iTreeNumber = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeStart.resize(m_iTreeNumber);
memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(int) * m_iTreeNumber);
pKDTMemFile += sizeof(int)*m_iTreeNumber;
int treeNodeSize = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeRoots.resize(treeNodeSize);
memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize);
return true;
}
bool LoadTrees(std::string sTreeFileName)
{
std::cout << "Load KDT From " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iTreeNumber, sizeof(int), 1, fp);
m_pTreeStart.resize(m_iTreeNumber);
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp);
m_pTreeRoots.resize(treeNodeSize);
fread(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
template <typename T>
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
for (char i = 0; i < m_iTreeNumber; i++) {
KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0);
}
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
if (p_query.worstDist() < tcell.distance) break;
KDTSearch(p_index, p_query, p_space, tcell.node, true, tcell.distance);
}
}
template <typename T>
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
KDTSearch(p_index, p_query, p_space, tcell.node, false, tcell.distance);
}
}
private:
template <typename T>
void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
COMMON::WorkSpace& p_space, const int node, const bool isInit, const float distBound) const {
if (node < 0)
{
int index = -node - 1;
if (index >= p_index->GetNumSamples()) return;
#ifdef PREFETCH
const char* data = (const char *)(p_index->GetSample(index));
_mm_prefetch(data, _MM_HINT_T0);
_mm_prefetch(data + 64, _MM_HINT_T0);
#endif
if (p_space.CheckAndSet(index)) return;
++p_space.m_iNumberOfTreeCheckedLeaves;
++p_space.m_iNumberOfCheckedLeaves;
p_space.m_NGQueue.insert(COMMON::HeapCell(index, p_index->ComputeDistance((const void*)p_query.GetTarget(), (const void*)data)));
return;
}
auto& tnode = m_pTreeRoots[node];
float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
float distanceBound = distBound + diff * diff;
int otherChild, bestChild;
if (diff < 0)
{
bestChild = tnode.left;
otherChild = tnode.right;
}
else
{
otherChild = tnode.left;
bestChild = tnode.right;
}
if (!isInit || distanceBound < p_query.worstDist())
{
p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
}
KDTSearch(p_index, p_query, p_space, bestChild, isInit, distBound);
}
template <typename T>
void DivideTree(VectorIndex* p_index, std::vector<int>& indices, int first, int last,
int index, int &iTreeSize) {
ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last);
int i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last);
if (i - 1 <= first)
{
m_pTreeRoots[index].left = -indices[first] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].left = iTreeSize;
DivideTree<T>(p_index, indices, first, i - 1, iTreeSize, iTreeSize);
}
if (last == i)
{
m_pTreeRoots[index].right = -indices[last] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].right = iTreeSize;
DivideTree<T>(p_index, indices, i, last, iTreeSize, iTreeSize);
}
}
template <typename T>
void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<int>& indices, const int first, const int last)
{
std::vector<float> meanValues(p_index->GetFeatureDim(), 0);
std::vector<float> varianceValues(p_index->GetFeatureDim(), 0);
int end = min(first + m_iSamples, last);
int count = end - first + 1;
// calculate the mean of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] += v[k];
}
}
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] /= count;
}
// calculate the variance of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
float dist = v[k] - meanValues[k];
varianceValues[k] += dist*dist;
}
}
// choose the split dimension as one of the dimension inside TOP_DIM maximum variance
node.split_dim = SelectDivisionDimension(varianceValues);
// determine the threshold
node.split_value = meanValues[node.split_dim];
}
int SelectDivisionDimension(const std::vector<float>& varianceValues) const
{
// Record the top maximum variances
std::vector<int> topind(m_numTopDimensionKDTSplit);
int num = 0;
// order the variances
for (int i = 0; i < varianceValues.size(); i++)
{
if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]])
{
if (num < m_numTopDimensionKDTSplit)
{
topind[num++] = i;
}
else
{
topind[num - 1] = i;
}
int j = num - 1;
// order the TOP_DIM variances
while (j > 0 && varianceValues[topind[j]] > varianceValues[topind[j - 1]])
{
std::swap(topind[j], topind[j - 1]);
j--;
}
}
}
// randomly choose a dimension from TOP_DIM
return topind[COMMON::Utils::rand_int(num)];
}
template <typename T>
int Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<int>& indices, const int first, const int last) const
{
int i = first;
int j = last;
// decide which child one point belongs
while (i <= j)
{
int ind = indices[i];
const T* v = (const T*)p_index->GetSample(ind);
float val = v[node.split_dim];
if (val < node.split_value)
{
i++;
}
else
{
std::swap(indices[i], indices[j]);
j--;
}
}
// if all the points in the node are equal,equally split the node into 2
if ((i == first) || (i == last + 1))
{
i = (first + last + 1) / 2;
}
return i;
}
private:
std::vector<int> m_pTreeStart;
std::vector<KDTNode> m_pTreeRoots;
public:
int m_iTreeNumber, m_numTopDimensionKDTSplit, m_iSamples;
};
}
}
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_NG_H_
#define _SPTAG_COMMON_NG_H_
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "Dataset.h"
#include "FineGrainedLock.h"
#include "QueryResultSet.h"
namespace SPTAG
{
namespace COMMON
{
class NeighborhoodGraph
{
public:
NeighborhoodGraph(): m_iTPTNumber(32),
m_iTPTLeafSize(2000),
m_iSamples(1000),
m_numTopDimensionTPTSplit(5),
m_iNeighborhoodSize(32),
m_iNeighborhoodScale(2),
m_iCEFScale(2),
m_iRefineIter(0),
m_iCEF(1000),
m_iMaxCheckForRefineGraph(10000) {}
~NeighborhoodGraph() {}
virtual void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist) = 0;
virtual void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) = 0;
virtual float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr) = 0;
template <typename T>
void BuildGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
{
std::cout << "build RNG graph!" << std::endl;
m_iGraphSize = index->GetNumSamples();
m_iNeighborhoodSize = m_iNeighborhoodSize * m_iNeighborhoodScale;
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
m_dataUpdateLock.resize(m_iGraphSize);
if (m_iGraphSize < 1000) {
RefineGraph<T>(index, idmap);
std::cout << "Build RNG Graph end!" << std::endl;
return;
}
{
COMMON::Dataset<float> NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
std::vector<std::vector<int>> TptreeDataIndices(m_iTPTNumber, std::vector<int>(m_iGraphSize));
std::vector<std::vector<std::pair<int, int>>> TptreeLeafNodes(m_iTPTNumber, std::vector<std::pair<int, int>>());
for (int i = 0; i < m_iGraphSize; i++)
for (int j = 0; j < m_iNeighborhoodSize; j++)
(NeighborhoodDists)[i][j] = MaxDist;
std::cout << "Parallel TpTree Partition begin " << std::endl;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iTPTNumber; i++)
{
Sleep(i * 100); std::srand(clock());
for (int j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j;
std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end());
PartitionByTptree<T>(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]);
std::cout << "Finish Getting Leaves for Tree " << i << std::endl;
}
std::cout << "Parallel TpTree Partition done" << std::endl;
for (int i = 0; i < m_iTPTNumber; i++)
{
#pragma omp parallel for schedule(dynamic)
for (int j = 0; j < TptreeLeafNodes[i].size(); j++)
{
int start_index = TptreeLeafNodes[i][j].first;
int end_index = TptreeLeafNodes[i][j].second;
if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%';
for (int x = start_index; x < end_index; x++)
{
for (int y = x + 1; y <= end_index; y++)
{
int p1 = TptreeDataIndices[i][x];
int p2 = TptreeDataIndices[i][y];
float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2));
if (idmap != nullptr) {
p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1);
p2 = (idmap->find(p2) == idmap->end()) ? p2 : idmap->at(p2);
}
COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], (NeighborhoodDists)[p1], m_iNeighborhoodSize);
COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], (NeighborhoodDists)[p2], m_iNeighborhoodSize);
}
}
}
TptreeDataIndices[i].clear();
TptreeLeafNodes[i].clear();
std::cout << std::endl;
}
TptreeDataIndices.clear();
TptreeLeafNodes.clear();
}
if (m_iMaxCheckForRefineGraph > 0) {
RefineGraph<T>(index, idmap);
}
}
template <typename T>
void RefineGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
{
m_iCEF *= m_iCEFScale;
m_iMaxCheckForRefineGraph *= m_iCEFScale;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++)
{
RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%";
}
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
m_iCEF /= m_iCEFScale;
m_iMaxCheckForRefineGraph /= m_iCEFScale;
m_iNeighborhoodSize /= m_iNeighborhoodScale;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++)
{
RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%";
}
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
if (idmap != nullptr) {
for (auto iter = idmap->begin(); iter != idmap->end(); iter++)
if (iter->first < 0)
{
m_pNeighborhoodGraph[-1 - iter->first][m_iNeighborhoodSize - 1] = -2 - iter->second;
}
}
}
template <typename T>
ErrorCode RefineGraph(VectorIndex* index, std::vector<int>& indices, std::vector<int>& reverseIndices,
std::string graphFileName, const std::unordered_map<int, int>* idmap = nullptr)
{
int R = (int)indices.size();
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < R; i++)
{
RefineNode<T>(index, indices[i], false);
int* nodes = m_pNeighborhoodGraph[indices[i]];
for (int j = 0; j < m_iNeighborhoodSize; j++)
{
if (nodes[j] < 0) nodes[j] = -1;
else nodes[j] = reverseIndices[nodes[j]];
}
if (idmap == nullptr || idmap->find(-1 - indices[i]) == idmap->end()) continue;
nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]);
}
std::ofstream graphOut(graphFileName, std::ios::binary);
if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
graphOut.write((char*)&R, sizeof(int));
graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
for (int i = 0; i < R; i++) {
graphOut.write((char*)m_pNeighborhoodGraph[indices[i]], sizeof(int) * m_iNeighborhoodSize);
}
graphOut.close();
return ErrorCode::Success;
}
template <typename T>
void RefineNode(VectorIndex* index, const int node, bool updateNeighbors)
{
COMMON::QueryResultSet<T> query((const T*)index->GetSample(node), m_iCEF + 1);
index->SearchIndex(query);
RebuildNeighbors(index, node, m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF + 1);
if (updateNeighbors) {
// update neighbors
for (int j = 0; j <= m_iCEF; j++)
{
BasicResult* item = query.GetResult(j);
if (item->VID < 0) break;
if (item->VID == node) continue;
std::lock_guard<std::mutex> lock(m_dataUpdateLock[item->VID]);
InsertNeighbors(index, item->VID, node, item->Dist);
}
}
}
template <typename T>
void PartitionByTptree(VectorIndex* index, std::vector<int>& indices, const int first, const int last,
std::vector<std::pair<int, int>> & leaves)
{
if (last - first <= m_iTPTLeafSize)
{
leaves.push_back(std::make_pair(first, last));
}
else
{
std::vector<float> Mean(index->GetFeatureDim(), 0);
int iIteration = 100;
int end = min(first + m_iSamples, last);
int count = end - first + 1;
// calculate the mean of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++)
{
Mean[k] += v[k];
}
}
for (int k = 0; k < index->GetFeatureDim(); k++)
{
Mean[k] /= count;
}
std::vector<BasicResult> Variance;
Variance.reserve(index->GetFeatureDim());
for (int j = 0; j < index->GetFeatureDim(); j++)
{
Variance.push_back(BasicResult(j, 0));
}
// calculate the variance of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++)
{
float dist = v[k] - Mean[k];
Variance[k].Dist += dist*dist;
}
}
std::sort(Variance.begin(), Variance.end(), COMMON::Compare);
std::vector<int> indexs(m_numTopDimensionTPTSplit);
std::vector<float> weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit);
float bestvariance = Variance[index->GetFeatureDim() - 1].Dist;
for (int i = 0; i < m_numTopDimensionTPTSplit; i++)
{
indexs[i] = Variance[index->GetFeatureDim() - 1 - i].VID;
bestweight[i] = 0;
}
bestweight[0] = 1;
float bestmean = Mean[indexs[0]];
std::vector<float> Val(count);
for (int i = 0; i < iIteration; i++)
{
float sumweight = 0;
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
weight[j] = float(rand() % 10000) / 5000.0f - 1.0f;
sumweight += weight[j] * weight[j];
}
sumweight = sqrt(sumweight);
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
weight[j] /= sumweight;
}
float mean = 0;
for (int j = 0; j < count; j++)
{
Val[j] = 0;
const T* v = (const T*)index->GetSample(indices[first + j]);
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
{
Val[j] += weight[k] * v[indexs[k]];
}
mean += Val[j];
}
mean /= count;
float var = 0;
for (int j = 0; j < count; j++)
{
float dist = Val[j] - mean;
var += dist * dist;
}
if (var > bestvariance)
{
bestvariance = var;
bestmean = mean;
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
bestweight[j] = weight[j];
}
}
}
int i = first;
int j = last;
// decide which child one point belongs
while (i <= j)
{
float val = 0;
const T* v = (const T*)index->GetSample(indices[i]);
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
{
val += bestweight[k] * v[indexs[k]];
}
if (val < bestmean)
{
i++;
}
else
{
std::swap(indices[i], indices[j]);
j--;
}
}
// if all the points in the node are equal,equally split the node into 2
if ((i == first) || (i == last + 1))
{
i = (first + last + 1) / 2;
}
Mean.clear();
Variance.clear();
Val.clear();
indexs.clear();
weight.clear();
bestweight.clear();
PartitionByTptree<T>(index, indices, first, i - 1, leaves);
PartitionByTptree<T>(index, indices, i, last, leaves);
}
}
bool LoadGraph(std::string sGraphFilename)
{
std::cout << "Load Graph From " << sGraphFilename << std::endl;
FILE * fp = fopen(sGraphFilename.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iGraphSize, sizeof(int), 1, fp);
fread(&m_iNeighborhoodSize, sizeof(int), 1, fp);
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
m_dataUpdateLock.resize(m_iGraphSize);
for (int i = 0; i < m_iGraphSize; i++)
{
fread((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Load Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true;
}
bool LoadGraphFromMemory(char* pGraphMemFile)
{
m_iGraphSize = *((int*)pGraphMemFile);
pGraphMemFile += sizeof(int);
m_iNeighborhoodSize = *((int*)pGraphMemFile);
pGraphMemFile += sizeof(int);
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize, (int*)pGraphMemFile);
m_dataUpdateLock.resize(m_iGraphSize);
return true;
}
bool SaveGraph(std::string sGraphFilename) const
{
std::cout << "Save Graph To " << sGraphFilename << std::endl;
FILE *fp = fopen(sGraphFilename.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iGraphSize, sizeof(int), 1, fp);
fwrite(&m_iNeighborhoodSize, sizeof(int), 1, fp);
for (int i = 0; i < m_iGraphSize; i++)
{
fwrite((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Save Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true;
}
bool SaveGraphToMemory(void **pGraphMemFile, int64_t &len) {
size_t size = sizeof(int) + sizeof(int) + sizeof(int) * m_iNeighborhoodSize * m_iGraphSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iGraphSize;
ptr += sizeof(int);
*(int*)ptr = m_iNeighborhoodSize;
ptr += sizeof(int);
for (int i = 0; i < m_iGraphSize; i++)
{
memcpy(ptr, (m_pNeighborhoodGraph)[i], sizeof(int) * m_iNeighborhoodSize);
ptr += sizeof(int) * m_iNeighborhoodSize;
}
*pGraphMemFile = mem;
len = size;
return true;
}
inline void AddBatch(int num) { m_pNeighborhoodGraph.AddBatch(num); m_iGraphSize += num; m_dataUpdateLock.resize(m_iGraphSize); }
inline int* operator[](int index) { return m_pNeighborhoodGraph[index]; }
inline const int* operator[](int index) const { return m_pNeighborhoodGraph[index]; }
inline void SetR(int rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); }
inline int R() const { return m_iGraphSize; }
static std::shared_ptr<NeighborhoodGraph> CreateInstance(std::string type);
protected:
// Graph structure
int m_iGraphSize;
COMMON::Dataset<int> m_pNeighborhoodGraph;
COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph
public:
int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit;
int m_iNeighborhoodSize, m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph;
};
}
}
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_QUERYRESULTSET_H_
#define _SPTAG_COMMON_QUERYRESULTSET_H_
#include "../SearchQuery.h"
namespace SPTAG
{
namespace COMMON
{
inline bool operator < (const BasicResult& lhs, const BasicResult& rhs)
{
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
inline bool Compare(const BasicResult& lhs, const BasicResult& rhs)
{
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
// Space to save temporary answer, similar with TopKCache
template<typename T>
class QueryResultSet : public QueryResult
{
public:
QueryResultSet(const T* _target, int _K) : QueryResult(_target, _K, false)
{
}
QueryResultSet(const QueryResultSet& other) : QueryResult(other)
{
}
inline void SetTarget(const T *p_target)
{
m_target = p_target;
}
inline const T* GetTarget() const
{
return reinterpret_cast<const T*>(m_target);
}
inline float worstDist() const
{
return m_results[0].Dist;
}
bool AddPoint(const int index, float dist)
{
if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID))
{
m_results[0].VID = index;
m_results[0].Dist = dist;
Heapify(m_resultNum);
return true;
}
return false;
}
inline void SortResult()
{
for (int i = m_resultNum - 1; i >= 0; i--)
{
std::swap(m_results[0], m_results[i]);
Heapify(i);
}
}
private:
void Heapify(int count)
{
int parent = 0, next = 1, maxidx = count - 1;
while (next < maxidx)
{
if (m_results[next] < m_results[next + 1]) next++;
if (m_results[parent] < m_results[next])
{
std::swap(m_results[next], m_results[parent]);
parent = next;
next = (parent << 1) + 1;
}
else break;
}
if (next == maxidx && m_results[parent] < m_results[next]) std::swap(m_results[parent], m_results[next]);
}
};
}
}
#endif // _SPTAG_COMMON_QUERYRESULTSET_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_RNG_H_
#define _SPTAG_COMMON_RNG_H_
#include "NeighborhoodGraph.h"
namespace SPTAG
{
namespace COMMON
{
class RelativeNeighborhoodGraph: public NeighborhoodGraph
{
public:
void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) {
int count = 0;
for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) {
const BasicResult& item = queryResults[j];
if (item.VID < 0) break;
if (item.VID == node) continue;
bool good = true;
for (int k = 0; k < count; k++) {
if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) {
good = false;
break;
}
}
if (good) nodes[count++] = item.VID;
}
for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1;
}
void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist)
{
int* nodes = m_pNeighborhoodGraph[node];
for (int k = 0; k < m_iNeighborhoodSize; k++)
{
int tmpNode = nodes[k];
if (tmpNode < -1) continue;
if (tmpNode < 0)
{
bool good = true;
for (int t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false;
break;
}
}
if (good) {
nodes[k] = insertNode;
}
break;
}
float tmpDist = index->ComputeDistance(index->GetSample(node), index->GetSample(tmpNode));
if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode))
{
bool good = true;
for (int t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false;
break;
}
}
if (good) {
nodes[k] = insertNode;
insertNode = tmpNode;
insertDist = tmpDist;
}
else {
break;
}
}
}
}
float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr)
{
int* correct = new int[samples];
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < samples; i++)
{
int x = COMMON::Utils::rand_int(m_iGraphSize);
//int x = i;
COMMON::QueryResultSet<void> query(nullptr, m_iCEF);
for (int y = 0; y < m_iGraphSize; y++)
{
if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue;
float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y));
query.AddPoint(y, dist);
}
query.SortResult();
int * exact_rng = new int[m_iNeighborhoodSize];
RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF);
correct[i] = 0;
for (int j = 0; j < m_iNeighborhoodSize; j++) {
if (exact_rng[j] == -1) {
correct[i] += m_iNeighborhoodSize - j;
break;
}
for (int k = 0; k < m_iNeighborhoodSize; k++)
if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) {
correct[i]++;
break;
}
}
delete[] exact_rng;
}
float acc = 0;
for (int i = 0; i < samples; i++) acc += float(correct[i]);
acc = acc / samples / m_iNeighborhoodSize;
delete[] correct;
return acc;
}
};
}
}
#endif
\ No newline at end of file
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_WORKSPACE_H_
#define _SPTAG_COMMON_WORKSPACE_H_
#include "CommonUtils.h"
#include "Heap.h"
namespace SPTAG
{
namespace COMMON
{
// node type in the priority queue
struct HeapCell
{
int node;
float distance;
HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {}
inline bool operator < (const HeapCell& rhs)
{
return distance < rhs.distance;
}
inline bool operator > (const HeapCell& rhs)
{
return distance > rhs.distance;
}
};
class OptHashPosVector
{
protected:
// Max loop number in one hash block.
static const int m_maxLoop = 8;
// Max pool size.
static const int m_poolSize = 8191;
// Could we use the second hash block.
bool m_secondHash;
// Record 2 hash tables.
// [0~m_poolSize + 1) is the first block.
// [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block;
int m_hashTable[(m_poolSize + 1) * 2];
inline unsigned hash_func2(int idx, int loop)
{
return ((unsigned)idx + loop) & m_poolSize;
}
inline unsigned hash_func(unsigned idx)
{
return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & m_poolSize;
}
public:
OptHashPosVector() {}
~OptHashPosVector() {}
void Init(int size)
{
m_secondHash = true;
clear();
}
void clear()
{
if (!m_secondHash)
{
// Clear first block.
memset(&m_hashTable[0], 0, sizeof(int)*(m_poolSize + 1));
}
else
{
// Clear all blocks.
memset(&m_hashTable[0], 0, 2 * sizeof(int) * (m_poolSize + 1));
m_secondHash = false;
}
}
inline bool CheckAndSet(int idx)
{
// Inner Index is begin from 1
return _CheckAndSet(&m_hashTable[0], idx + 1) == 0;
}
inline int _CheckAndSet(int* hashTable, int idx)
{
unsigned index, loop;
// Get first hash position.
index = hash_func(idx);
for (loop = 0; loop < m_maxLoop; ++loop)
{
if (!hashTable[index])
{
// index first match and record it.
hashTable[index] = idx;
return 1;
}
if (hashTable[index] == idx)
{
// Hit this item in hash table.
return 0;
}
// Get next hash position.
index = hash_func2(index, loop);
}
if (hashTable == &m_hashTable[0])
{
// Use second hash block.
m_secondHash = true;
return _CheckAndSet(&m_hashTable[m_poolSize + 1], idx);
}
// Do not include this item.
return -1;
}
};
// Variables for each single NN search
struct WorkSpace
{
void Initialize(int maxCheck, int dataSize)
{
nodeCheckStatus.Init(dataSize);
m_SPTQueue.Resize(maxCheck * 10);
m_NGQueue.Resize(maxCheck * 30);
m_iNumberOfTreeCheckedLeaves = 0;
m_iNumberOfCheckedLeaves = 0;
m_iContinuousLimit = maxCheck / 64;
m_iMaxCheck = maxCheck;
m_iNumOfContinuousNoBetterPropagation = 0;
}
void Reset(int maxCheck)
{
nodeCheckStatus.clear();
m_SPTQueue.clear();
m_NGQueue.clear();
m_iNumberOfTreeCheckedLeaves = 0;
m_iNumberOfCheckedLeaves = 0;
m_iContinuousLimit = maxCheck / 64;
m_iMaxCheck = maxCheck;
m_iNumOfContinuousNoBetterPropagation = 0;
}
inline bool CheckAndSet(int idx)
{
return nodeCheckStatus.CheckAndSet(idx);
}
OptHashPosVector nodeCheckStatus;
//OptHashPosVector nodeCheckStatus;
// counter for dynamic pivoting
int m_iNumOfContinuousNoBetterPropagation;
int m_iContinuousLimit;
int m_iNumberOfTreeCheckedLeaves;
int m_iNumberOfCheckedLeaves;
int m_iMaxCheck;
// Prioriy queue used for neighborhood graph
Heap<HeapCell> m_NGQueue;
// Priority queue Used for BKT-Tree
Heap<HeapCell> m_SPTQueue;
};
}
}
#endif // _SPTAG_COMMON_WORKSPACE_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_WORKSPACEPOOL_H_
#define _SPTAG_COMMON_WORKSPACEPOOL_H_
#include "WorkSpace.h"
#include <list>
#include <mutex>
namespace SPTAG
{
namespace COMMON
{
class WorkSpacePool
{
public:
WorkSpacePool(int p_maxCheck, int p_vectorCount);
virtual ~WorkSpacePool();
std::shared_ptr<WorkSpace> Rent();
void Return(const std::shared_ptr<WorkSpace>& p_workSpace);
void Init(int size);
private:
std::list<std::shared_ptr<WorkSpace>> m_workSpacePool;
std::mutex m_workSpacePoolMutex;
int m_maxCheck;
int m_vectorCount;
};
}
}
#endif // _SPTAG_COMMON_WORKSPACEPOOL_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMONDATASTRUCTURE_H_
#define _SPTAG_COMMONDATASTRUCTURE_H_
#include "Common.h"
namespace SPTAG
{
class ByteArray
{
public:
ByteArray();
ByteArray(ByteArray&& p_right);
ByteArray(std::uint8_t* p_array, std::size_t p_length, bool p_transferOnwership);
ByteArray(std::uint8_t* p_array, std::size_t p_length, std::shared_ptr<std::uint8_t> p_dataHolder);
ByteArray(const ByteArray& p_right);
ByteArray& operator= (const ByteArray& p_right);
ByteArray& operator= (ByteArray&& p_right);
~ByteArray();
static ByteArray Alloc(std::size_t p_length);
std::uint8_t* Data() const;
std::size_t Length() const;
void SetData(std::uint8_t* p_array, std::size_t p_length);
std::shared_ptr<std::uint8_t> DataHolder() const;
void Clear();
const static ByteArray c_empty;
private:
std::uint8_t* m_data;
std::size_t m_length;
// Notice this is holding an array. Set correct deleter for this.
std::shared_ptr<std::uint8_t> m_dataHolder;
};
} // namespace SPTAG
#endif // _SPTAG_COMMONDATASTRUCTURE_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineVectorValueType
DefineVectorValueType(Int8, std::int8_t)
DefineVectorValueType(UInt8, std::uint8_t)
DefineVectorValueType(Int16, std::int16_t)
DefineVectorValueType(Float, float)
#endif // DefineVectorValueType
#ifdef DefineDistCalcMethod
DefineDistCalcMethod(L2)
DefineDistCalcMethod(Cosine)
#endif // DefineDistCalcMethod
#ifdef DefineErrorCode
// 0x0000 ~ 0x0FFF General Status
DefineErrorCode(Success, 0x0000)
DefineErrorCode(Fail, 0x0001)
DefineErrorCode(FailedOpenFile, 0x0002)
DefineErrorCode(FailedCreateFile, 0x0003)
DefineErrorCode(ParamNotFound, 0x0010)
DefineErrorCode(FailedParseValue, 0x0011)
// 0x1000 ~ 0x1FFF Index Build Status
// 0x2000 ~ 0x2FFF Index Serve Status
// 0x3000 ~ 0x3FFF Helper Function Status
DefineErrorCode(ReadIni_FailedParseSection, 0x3000)
DefineErrorCode(ReadIni_FailedParseParam, 0x3001)
DefineErrorCode(ReadIni_DuplicatedSection, 0x3002)
DefineErrorCode(ReadIni_DuplicatedParam, 0x3003)
// 0x4000 ~ 0x4FFF Socket Library Status
DefineErrorCode(Socket_FailedResolveEndPoint, 0x4000)
DefineErrorCode(Socket_FailedConnectToEndPoint, 0x4001)
#endif // DefineErrorCode
#ifdef DefineIndexAlgo
DefineIndexAlgo(BKT)
DefineIndexAlgo(KDT)
#endif // DefineIndexAlgo
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_KDT_INDEX_H_
#define _SPTAG_KDT_INDEX_H_
#include "../Common.h"
#include "../VectorIndex.h"
#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/KDTree.h"
#include "inc/Helper/StringConvert.h"
#include "inc/Helper/SimpleIniReader.h"
#include <functional>
#include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG
{
namespace Helper
{
class IniReader;
}
namespace KDT
{
template<typename T>
class Index : public VectorIndex
{
private:
// data points
COMMON::Dataset<T> m_pSamples;
// KDT structures.
COMMON::KDTree m_pTrees;
// Graph structure
COMMON::RelativeNeighborhoodGraph m_pGraph;
std::string m_sKDTFilename;
std::string m_sGraphFilename;
std::string m_sDataPointsFilename;
std::mutex m_dataLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots;
public:
Index()
{
#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \
#include "inc/Core/KDT/ParameterDefinitionList.h"
#undef DefineKDTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; }
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; }
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
};
} // namespace KDT
} // namespace SPTAG
#endif // _SPTAG_KDT_INDEX_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineKDTParameter
// DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber")
DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit")
DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "NumSamplesKDTSplitConsideration")
DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit")
DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
DefineKDTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
DefineKDTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
#endif
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_METADATASET_H_
#define _SPTAG_METADATASET_H_
#include "CommonDataStructure.h"
#include <iostream>
#include <fstream>
namespace SPTAG
{
class MetadataSet
{
public:
MetadataSet();
virtual ~MetadataSet();
virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0;
virtual SizeType Count() const = 0;
virtual bool Available() const = 0;
virtual void AddBatch(MetadataSet& data) = 0;
virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0;
virtual ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) = 0;
virtual ErrorCode LoadMetadataFromMemory(void *pGraphMemFile) = 0;
virtual ErrorCode RefineMetadata(std::vector<int>& indices, const std::string& p_folderPath);
static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst);
};
class FileMetadataSet : public MetadataSet
{
public:
FileMetadataSet(const std::string& p_metaFile, const std::string& p_metaindexFile);
~FileMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const;
SizeType Count() const;
bool Available() const;
void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private:
std::ifstream* m_fp = nullptr;
std::vector<std::uint64_t> m_pOffsets;
SizeType m_count;
std::string m_metaFile;
std::string m_metaindexFile;
std::vector<std::uint8_t> m_newdata;
};
class MemMetadataSet : public MetadataSet
{
public:
MemMetadataSet() = default;
MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count);
~MemMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const;
SizeType Count() const;
bool Available() const;
void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private:
std::vector<std::uint64_t> m_offsets;
SizeType m_count;
ByteArray m_metadataHolder;
ByteArray m_offsetHolder;
std::vector<std::uint8_t> m_newdata;
};
} // namespace SPTAG
#endif // _SPTAG_METADATASET_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SEARCHQUERY_H_
#define _SPTAG_SEARCHQUERY_H_
#include "CommonDataStructure.h"
#include <cstring>
namespace SPTAG
{
struct BasicResult
{
int VID;
float Dist;
BasicResult() : VID(-1), Dist(MaxDist) {}
BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {}
};
// Space to save temporary answer, similar with TopKCache
class QueryResult
{
public:
typedef BasicResult* iterator;
typedef const BasicResult* const_iterator;
QueryResult()
: m_target(nullptr),
m_resultNum(0),
m_withMeta(false)
{
}
QueryResult(const void* p_target, int p_resultNum, bool p_withMeta)
: m_target(nullptr),
m_resultNum(0),
m_withMeta(false)
{
Init(p_target, p_resultNum, p_withMeta);
}
QueryResult(const void* p_target, int p_resultNum, std::vector<BasicResult>& p_results)
: m_target(p_target),
m_resultNum(p_resultNum),
m_withMeta(false)
{
p_results.resize(p_resultNum);
m_results.reset(p_results.data());
}
QueryResult(const QueryResult& p_other)
: m_target(p_other.m_target),
m_resultNum(p_other.m_resultNum),
m_withMeta(p_other.m_withMeta)
{
if (m_resultNum > 0)
{
m_results.reset(new BasicResult[m_resultNum]);
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
if (m_withMeta)
{
m_metadatas.reset(new ByteArray[m_resultNum]);
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
}
}
QueryResult& operator=(const QueryResult& p_other)
{
Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta);
if (m_resultNum > 0)
{
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
if (m_withMeta)
{
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
}
return *this;
}
~QueryResult()
{
}
inline void Init(const void* p_target, int p_resultNum, bool p_withMeta)
{
m_target = p_target;
if (p_resultNum > m_resultNum)
{
m_results.reset(new BasicResult[p_resultNum]);
}
if (p_withMeta && (!m_withMeta || p_resultNum > m_resultNum))
{
m_metadatas.reset(new ByteArray[p_resultNum]);
}
m_resultNum = p_resultNum;
m_withMeta = p_withMeta;
}
inline int GetResultNum() const
{
return m_resultNum;
}
inline const void* GetTarget()
{
return m_target;
}
inline void SetTarget(const void* p_target)
{
m_target = p_target;
}
inline BasicResult* GetResult(int i) const
{
return i < m_resultNum ? m_results.get() + i : nullptr;
}
inline void SetResult(int p_index, int p_VID, float p_dist)
{
if (p_index < m_resultNum)
{
m_results[p_index].VID = p_VID;
m_results[p_index].Dist = p_dist;
}
}
inline BasicResult* GetResults() const
{
return m_results.get();
}
inline bool WithMeta() const
{
return m_withMeta;
}
inline const ByteArray& GetMetadata(int p_index) const
{
if (p_index < m_resultNum && m_withMeta)
{
return m_metadatas[p_index];
}
return ByteArray::c_empty;
}
inline void SetMetadata(int p_index, ByteArray p_metadata)
{
if (p_index < m_resultNum && m_withMeta)
{
m_metadatas[p_index] = std::move(p_metadata);
}
}
inline void Reset()
{
for (int i = 0; i < m_resultNum; i++)
{
m_results[i].VID = -1;
m_results[i].Dist = MaxDist;
}
if (m_withMeta)
{
for (int i = 0; i < m_resultNum; i++)
{
m_metadatas[i].Clear();
}
}
}
iterator begin()
{
return m_results.get();
}
iterator end()
{
return m_results.get() + m_resultNum;
}
const_iterator begin() const
{
return m_results.get();
}
const_iterator end() const
{
return m_results.get() + m_resultNum;
}
protected:
const void* m_target;
int m_resultNum;
bool m_withMeta;
std::unique_ptr<BasicResult[]> m_results;
std::unique_ptr<ByteArray[]> m_metadatas;
};
} // namespace SPTAG
#endif // _SPTAG_SEARCHQUERY_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_VECTORINDEX_H_
#define _SPTAG_VECTORINDEX_H_
#include "Common.h"
#include "SearchQuery.h"
#include "VectorSet.h"
#include "MetadataSet.h"
#include "inc/Helper/SimpleIniReader.h"
namespace SPTAG
{
class VectorIndex
{
public:
VectorIndex();
virtual ~VectorIndex();
virtual ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout) = 0;
virtual ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader) = 0;
virtual ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen) = 0;
virtual ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs) = 0;
virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0;
virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum) = 0;
//virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0;
//virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0;
virtual float ComputeDistance(const void* pX, const void* pY) const = 0;
virtual const void* GetSample(const int idx) const = 0;
virtual int GetFeatureDim() const = 0;
virtual int GetNumSamples() const = 0;
virtual DistCalcMethod GetDistCalcMethod() const = 0;
virtual IndexAlgoType GetIndexAlgoType() const = 0;
virtual VectorValueType GetVectorValueType() const = 0;
virtual int GetNumThreads() const = 0;
virtual std::string GetParameter(const char* p_param) const = 0;
virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
virtual ErrorCode LoadIndex(const std::string& p_folderPath);
virtual ErrorCode SaveIndex(const std::string& p_folderPath);
virtual ErrorCode BuildIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector<BasicResult>& p_results) const;
virtual ErrorCode AddIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
virtual std::string GetParameter(const std::string& p_param) const;
virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
virtual ByteArray GetMetadata(IndexType p_vectorID) const;
virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath);
virtual std::string GetIndexName() const
{
if (m_sIndexName == "")
return Helper::Convert::ConvertToString(GetIndexAlgoType());
return m_sIndexName;
}
virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; }
static std::shared_ptr<VectorIndex> CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype);
static ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2);
static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr<VectorIndex>& p_vectorIndex);
protected:
std::string m_sIndexName;
std::shared_ptr<MetadataSet> m_pMetadata;
};
} // namespace SPTAG
#endif // _SPTAG_VECTORINDEX_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_VECTORSET_H_
#define _SPTAG_VECTORSET_H_
#include "CommonDataStructure.h"
namespace SPTAG
{
class VectorSet
{
public:
VectorSet();
virtual ~VectorSet();
virtual VectorValueType GetValueType() const = 0;
virtual void* GetVector(IndexType p_vectorID) const = 0;
virtual void* GetData() const = 0;
virtual SizeType Dimension() const = 0;
virtual SizeType Count() const = 0;
virtual bool Available() const = 0;
virtual ErrorCode Save(const std::string& p_vectorFile) const = 0;
};
class BasicVectorSet : public VectorSet
{
public:
BasicVectorSet(const ByteArray& p_bytesArray,
VectorValueType p_valueType,
SizeType p_dimension,
SizeType p_vectorCount);
virtual ~BasicVectorSet();
virtual VectorValueType GetValueType() const;
virtual void* GetVector(IndexType p_vectorID) const;
virtual void* GetData() const;
virtual SizeType Dimension() const;
virtual SizeType Count() const;
virtual bool Available() const;
virtual ErrorCode Save(const std::string& p_vectorFile) const;
private:
ByteArray m_data;
VectorValueType m_valueType;
SizeType m_dimension;
SizeType m_vectorCount;
SizeType m_perVectorDataSize;
};
} // namespace SPTAG
#endif // _SPTAG_VECTORSET_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_ARGUMENTSPARSER_H_
#define _SPTAG_HELPER_ARGUMENTSPARSER_H_
#include "inc/Helper/StringConvert.h"
#include <cstdint>
#include <cstddef>
#include <memory>
#include <vector>
#include <string>
namespace SPTAG
{
namespace Helper
{
class ArgumentsParser
{
public:
ArgumentsParser();
virtual ~ArgumentsParser();
virtual bool Parse(int p_argc, char** p_args);
virtual void PrintHelp();
protected:
class IArgument
{
public:
IArgument();
virtual ~IArgument();
virtual bool ParseValue(int& p_restArgc, char** (&p_args)) = 0;
virtual void PrintDescription(FILE* p_output) = 0;
virtual bool IsRequiredButNotSet() const = 0;
};
template<typename DataType>
class ArgumentT : public IArgument
{
public:
ArgumentT(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
bool p_followedValue,
const DataType& p_switchAsValue,
bool p_isRequired)
: m_value(p_target),
m_representStringShort(p_representStringShort),
m_representString(p_representString),
m_description(p_description),
m_followedValue(p_followedValue),
c_switchAsValue(p_switchAsValue),
m_isRequired(p_isRequired),
m_isSet(false)
{
}
virtual ~ArgumentT()
{
}
virtual bool ParseValue(int& p_restArgc, char** (&p_args))
{
if (0 == p_restArgc)
{
return true;
}
if (0 != strcmp(*p_args, m_representString.c_str())
&& 0 != strcmp(*p_args, m_representStringShort.c_str()))
{
return true;
}
if (!m_followedValue)
{
m_value = c_switchAsValue;
--p_restArgc;
++p_args;
m_isSet = true;
return true;
}
if (p_restArgc < 2)
{
return false;
}
DataType tmp;
if (!Helper::Convert::ConvertStringTo(p_args[1], tmp))
{
return false;
}
m_value = std::move(tmp);
p_restArgc -= 2;
p_args += 2;
m_isSet = true;
return true;
}
virtual void PrintDescription(FILE* p_output)
{
std::size_t padding = 30;
if (!m_representStringShort.empty())
{
fprintf(p_output, "%s", m_representStringShort.c_str());
padding -= m_representStringShort.size();
}
if (!m_representString.empty())
{
if (!m_representStringShort.empty())
{
fprintf(p_output, ", ");
padding -= 2;
}
fprintf(p_output, "%s", m_representString.c_str());
padding -= m_representString.size();
}
if (m_followedValue)
{
fprintf(p_output, " <value>");
padding -= 8;
}
while (padding-- > 0)
{
fputc(' ', p_output);
}
fprintf(p_output, "%s", m_description.c_str());
}
virtual bool IsRequiredButNotSet() const
{
return m_isRequired && !m_isSet;
}
private:
DataType & m_value;
std::string m_representStringShort;
std::string m_representString;
std::string m_description;
bool m_followedValue;
const DataType c_switchAsValue;
bool m_isRequired;
bool m_isSet;
};
template<typename DataType>
void AddRequiredOption(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
true,
DataType(),
true)));
}
template<typename DataType>
void AddOptionalOption(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
true,
DataType(),
false)));
}
template<typename DataType>
void AddRequiredSwitch(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
const DataType& p_switchAsValue)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
false,
p_switchAsValue,
true)));
}
template<typename DataType>
void AddOptionalSwitch(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
const DataType& p_switchAsValue)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
false,
p_switchAsValue,
false)));
}
private:
std::vector<std::shared_ptr<IArgument>> m_arguments;
};
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_ARGUMENTSPARSER_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_BASE64ENCODE_H_
#define _SPTAG_HELPER_BASE64ENCODE_H_
#include <cstdint>
#include <cstddef>
#include <ostream>
namespace SPTAG
{
namespace Helper
{
namespace Base64
{
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& p_outLen);
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen);
bool Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen);
std::size_t CapacityForEncode(std::size_t p_inLen);
std::size_t CapacityForDecode(std::size_t p_inLen);
} // namespace Base64
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_BASE64ENCODE_H_
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_COMMONHELPER_H_
#define _SPTAG_HELPER_COMMONHELPER_H_
#include "../Core/Common.h"
#include <string>
#include <vector>
#include <cctype>
#include <functional>
#include <limits>
#include <cerrno>
namespace SPTAG
{
namespace Helper
{
namespace StrUtils
{
void ToLowerInPlace(std::string& p_str);
std::vector<std::string> SplitString(const std::string& p_str, const std::string& p_separator);
std::pair<const char*, const char*> FindTrimmedSegment(const char* p_begin,
const char* p_end,
const std::function<bool(char)>& p_isSkippedChar);
bool StartsWith(const char* p_str, const char* p_prefix);
bool StrEqualIgnoreCase(const char* p_left, const char* p_right);
} // namespace StrUtils
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_COMMONHELPER_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册