Index.h 5.5 KB
Newer Older
X
xj.lin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef _SPTAG_KDT_INDEX_H_
#define _SPTAG_KDT_INDEX_H_

#include "../Common.h"
#include "../VectorIndex.h"

#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/KDTree.h"
X
xiaojun.lin 已提交
18
#include "inc/Helper/ConcurrentSet.h"
X
xj.lin 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
#include "inc/Helper/StringConvert.h"
#include "inc/Helper/SimpleIniReader.h"

#include <functional>
#include <memory>
#include <mutex>

namespace SPTAG
{

    // Forward declaration only: within this header Helper::IniReader appears
    // solely as a reference parameter (see Index::LoadConfig), which does not
    // require the complete type.
    namespace Helper
    {
        class IniReader;
    }

    namespace KDT
    {
        template<typename T>
        class Index : public VectorIndex
        {
        private:
            // Data points of element type T. GetNumSamples() returns its R()
            // and GetFeatureDim() returns its C().
            COMMON::Dataset<T> m_pSamples;

            // KDT structures.
            COMMON::KDTree m_pTrees;

            // Graph structure (a relative neighborhood graph).
            COMMON::RelativeNeighborhoodGraph m_pGraph;

            // NOTE(review): presumably the file names used when the trees, graph
            // and data points are saved/loaded -- they are not referenced in this
            // header, so confirm against the implementation.
            std::string m_sKDTFilename;
            std::string m_sGraphFilename;
            std::string m_sDataPointsFilename;
X
xiaojun.lin 已提交
51
            // NOTE(review): presumably the file name for the persisted deleted-ID
            // set (m_deletedID) -- not referenced in this header; confirm.
            std::string m_sDeleteDataPointsFilename;
X
xj.lin 已提交
52

X
xiaojun.lin 已提交
53 54 55
            std::mutex m_dataAddLock; // protect data and graph
            // IDs recorded as deleted: ContainSample() returns false for any ID
            // found here, and NeedRefine() compares this set's size to a threshold.
            Helper::Concurrent::ConcurrentSet<SizeType> m_deletedID;
            // Multiplied by GetNumSamples() to form the NeedRefine() threshold.
            float m_fDeletePercentageForRefine;
X
xj.lin 已提交
56 57 58 59
            // Uniquely owned workspace pool.
            // NOTE(review): presumably hands out COMMON::WorkSpace objects to the
            // search routines -- not used in this header; confirm.
            std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
            
            // Returned by GetNumThreads().
            int m_iNumberOfThreads;
            // Returned by GetDistCalcMethod(); the constructor passes it to
            // COMMON::DistanceCalcSelector<T> to pick m_fComputeDistance.
            DistCalcMethod m_iDistCalcMethod;
X
xiaojun.lin 已提交
60
            // Distance function pointer. Assigned in the constructor from
            // COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod) and invoked by
            // ComputeDistance() with m_pSamples.C() as the length argument.
            float(*m_fComputeDistance)(const T* pX, const T* pY, DimensionType length);
X
xj.lin 已提交
61 62 63 64 65 66 67
 
            // Returned by GetCurrMaxCheck().
            int m_iMaxCheck;
            // NOTE(review): the three values below are not referenced in this
            // header; presumably search-tuning parameters consumed by the
            // out-of-line search implementation -- confirm there.
            int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
            int m_iNumberOfInitialDynamicPivots;
            int m_iNumberOfOtherDynamicPivots;
        public:
            // Default constructor.
            // Steps, in order:
            //   1. Expands every DefineKDTParameter(...) entry found in
            //      inc/Core/KDT/ParameterDefinitionList.h into the statement
            //      `VarName = DefaultValue;`.
            //   2. Names the sample dataset "Vector".
            //   3. Selects the distance function from m_iDistCalcMethod.
            // NOTE(review): step 3 reads m_iDistCalcMethod, so the parameter list
            // presumably assigns it in step 1 -- confirm in the list header.
            Index()
X
xiaojun.lin 已提交
68
            {
X
xj.lin 已提交
69 70 71 72 73
// The macro below is defined only for the duration of the #include that follows.
#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
                VarName = DefaultValue; \

#include "inc/Core/KDT/ParameterDefinitionList.h"
#undef DefineKDTParameter
X
xiaojun.lin 已提交
74 75 76 77
                
                m_pSamples.SetName("Vector");
                // Must run after the parameter assignments above, since it reads m_iDistCalcMethod.
                m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
            }
X
xj.lin 已提交
78 79 80

            // Nothing to release by hand: every member cleans up through its own destructor.
            ~Index() = default;

X
xiaojun.lin 已提交
81
            // Sample count, as reported by m_pSamples.R().
            inline SizeType GetNumSamples() const
            {
                return m_pSamples.R();
            }
S
shengjun.li 已提交
82
            // Returns sizeof(*this), i.e. the static size of the Index object itself.
            // NOTE(review): sizeof does not include memory the members allocate
            // elsewhere; BufferSize() reports per-component figures -- confirm that
            // callers really want the object footprint here.
            inline SizeType GetIndexSize() const
            {
                return sizeof(*this);
            }
X
xiaojun.lin 已提交
83
            // Vector dimensionality, as reported by m_pSamples.C().
            inline DimensionType GetFeatureDim() const
            {
                return m_pSamples.C();
            }
X
xj.lin 已提交
84 85 86 87 88 89 90 91
            
            // Current value of m_iMaxCheck.
            inline int GetCurrMaxCheck() const
            {
                return m_iMaxCheck;
            }
            // Current value of m_iNumberOfThreads.
            inline int GetNumThreads() const
            {
                return m_iNumberOfThreads;
            }
            // Distance method currently stored in m_iDistCalcMethod.
            inline DistCalcMethod GetDistCalcMethod() const
            {
                return m_iDistCalcMethod;
            }
            // Always IndexAlgoType::KDT for this index class.
            inline IndexAlgoType GetIndexAlgoType() const
            {
                return IndexAlgoType::KDT;
            }
            // Value-type enum for the element type T, obtained from GetEnumValueType<T>().
            inline VectorValueType GetVectorValueType() const
            {
                return GetEnumValueType<T>();
            }
            
            inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
X
xiaojun.lin 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
            // Untyped pointer to the sample stored at idx, i.e. m_pSamples[idx].
            // idx is forwarded as-is; this method performs no range or
            // deleted-state check of its own.
            inline const void* GetSample(const SizeType idx) const
            {
                // The previous `(void*)` cast could silently drop const before the
                // value was converted back to const void*; cast straight to the
                // return type instead.
                return static_cast<const void*>(m_pSamples[idx]);
            }
            // True when idx is not present in the deleted-ID set.
            // Only m_deletedID is consulted; idx is not compared against GetNumSamples().
            inline bool ContainSample(const SizeType idx) const
            {
                const bool markedDeleted = m_deletedID.contains(idx);
                return !markedDeleted;
            }
            // True once the number of deleted IDs reaches the refine threshold,
            // GetNumSamples() * m_fDeletePercentageForRefine converted to size_t.
            // NOTE(review): the comparison is >=, so a threshold of 0 makes this
            // return true even with nothing deleted -- confirm that is intended.
            inline bool NeedRefine() const
            {
                const size_t refineThreshold = static_cast<size_t>(GetNumSamples() * m_fDeletePercentageForRefine);
                return m_deletedID.size() >= refineThreshold;
            }
            // Collects the BufferSize() figure of each component into one vector.
            // Order of entries: samples, trees, graph, deleted-ID set.
            // Returns a newly allocated vector with exactly those four values.
            std::shared_ptr<std::vector<std::uint64_t>> BufferSize() const
            {
                auto buffersize = std::make_shared<std::vector<std::uint64_t>>();
                buffersize->reserve(4); // exactly four entries are appended below
                buffersize->push_back(m_pSamples.BufferSize());
                buffersize->push_back(m_pTrees.BufferSize());
                buffersize->push_back(m_pGraph.BufferSize());
                buffersize->push_back(m_deletedID.bufferSize());
                // Return the local by name: wrapping it in std::move would block
                // copy elision and gains nothing.
                return buffersize;
            }

            // Save/load/build entry points. Only the declarations appear here; the
            // bodies are defined elsewhere, so the notes below are limited to what
            // the signatures show.

            // Takes an output stream for the configuration; const, so it does not
            // modify the index object.
            ErrorCode SaveConfig(std::ostream& p_configout) const;
            // Two overloads: one takes a folder path, the other a list of output
            // stream pointers.
            // NOTE(review): the required number/order of streams is not visible
            // here -- confirm in the implementation.
            ErrorCode SaveIndexData(const std::string& p_folderPath);
            ErrorCode SaveIndexData(const std::vector<std::ostream*>& p_indexStreams);

            // Takes an ini reader for the configuration.
            ErrorCode LoadConfig(Helper::IniReader& p_reader);
            // Load counterparts: from a folder path, or from in-memory byte arrays.
            ErrorCode LoadIndexData(const std::string& p_folderPath);
            ErrorCode LoadIndexDataFromMemory(const std::vector<ByteArray>& p_indexBlobs);

            // NOTE(review): p_data is presumably p_vectorNum vectors of
            // p_dimension elements of T laid out contiguously -- confirm.
            ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension);
X
xj.lin 已提交
114
            // Search entry point; const, and takes the query by non-const reference.
            // NOTE(review): presumably writes the results back into p_query --
            // body not shown here; confirm.
            ErrorCode SearchIndex(QueryResult &p_query) const;
X
xiaojun.lin 已提交
115 116 117
            // Add/delete entry points (declarations only; bodies defined elsewhere).
            // p_start is optional and defaults to nullptr.
            // NOTE(review): presumably receives the ID assigned to the first added
            // vector when non-null -- confirm in the implementation.
            ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr);
            // Two delete overloads: by vector contents, or by a single ID.
            ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum);
            ErrorCode DeleteIndex(const SizeType& p_id);
X
xj.lin 已提交
118 119 120 121 122

            // String-based parameter access (declarations only).
            // NOTE(review): the accepted names are presumably the RepresentStr
            // entries of ParameterDefinitionList.h -- confirm in the implementation.
            ErrorCode SetParameter(const char* p_param, const char* p_value);
            std::string GetParameter(const char* p_param) const;

            // Refine overload that takes a folder path; see NeedRefine() for the
            // threshold check exposed by this class.
            ErrorCode RefineIndex(const std::string& p_folderPath);
X
xiaojun.lin 已提交
123 124 125 126
            // Refine overload that takes a list of output stream pointers.
            // NOTE(review): required stream count/order is not visible here -- confirm.
            ErrorCode RefineIndex(const std::vector<std::ostream*>& p_indexStreams);

        private:
            // Private search helper that additionally receives a deleted-ID set.
            // NOTE(review): presumably skips IDs contained in p_deleted -- body not
            // shown here; confirm.
            void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet<SizeType> &p_deleted) const;
X
xj.lin 已提交
127 128 129 130 131 132
            // Private search helper with the same query/workspace parameters as
            // SearchIndexWithDeleted but no deleted-ID set.
            // NOTE(review): presumably used when no IDs are deleted -- confirm.
            void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
        };
    } // namespace KDT
} // namespace SPTAG

#endif // _SPTAG_KDT_INDEX_H_