// Clustering.h
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#ifndef FAISS_CLUSTERING_H
#define FAISS_CLUSTERING_H
#include <faiss/Index.h>

#include <vector>

namespace faiss {

/**
 * Clustering algorithm variants.
 *
 * NOTE(review): the mapping from each value to its implementation lives in
 * the .cpp file; the per-value notes below are inferred from the names and
 * from the Clustering::kmeans_*_algorithm declarations — confirm there.
 */
enum ClusteringType
{
    K_MEANS,           ///< k-means
    K_MEANS_PLUS_PLUS, ///< k-means++ (cf. Clustering::kmeans_plus_plus_algorithm)
    K_MEANS_TWO,       ///< TODO(review): semantics not visible in this header
};

/// Global selector for the clustering algorithm. Per the original note the
/// default is K_MEANS; the definition and its initializer live in the .cpp.
extern ClusteringType clustering_type;

/** Class for the clustering parameters. Can be passed to the
 * constructor of the Clustering object.
 */
struct ClusteringParameters {
    int niter;          ///< clustering iterations
    int nredo;          ///< redo clustering this many times and keep best

    bool verbose;       ///< enable verbose output (exact output defined in the .cpp)
    bool spherical;     ///< do we want normalized centroids?
    bool int_centroids; ///< round centroids coordinates to integer
    bool update_index;  ///< re-train index after each iteration?
    bool frozen_centroids;  ///< use the centroids provided as input and do not change them during iterations

    int min_points_per_centroid; ///< fewer training points per centroid than this gives a warning
    int max_points_per_centroid; ///< to limit size of dataset

    int seed; ///< seed for the random number generator

    size_t decode_block_size;  ///< how many vectors at a time to decode

    /// Sets reasonable defaults (the actual values are assigned in the .cpp).
    ClusteringParameters ();
};


/// Statistics recorded for one clustering iteration
/// (collected in Clustering::iteration_stats).
struct ClusteringIterationStats {
    float obj;               ///< objective value (sum of distances reported by index)
    double time;             ///< seconds for iteration
    double time_search;      ///< seconds for just search
    double imbalance_factor; ///< imbalance factor of iteration
    int nsplit;              ///< number of cluster splits
};


/** K-means clustering based on assignment - centroid update iterations
J
JinHai-CN 已提交
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
 *
 * The clustering is based on an Index object that assigns training
 * points to the centroids. Therefore, at each iteration the centroids
 * are added to the index.
 *
 * On output, the centoids table is set to the latest version
 * of the centroids and they are also added to the index. If the
 * centroids table it is not empty on input, it is also used for
 * initialization.
 *
 */
struct Clustering: ClusteringParameters {
    typedef Index::idx_t idx_t;
    size_t d;              ///< dimension of the vectors
    size_t k;              ///< nb of centroids

C
Cai Yudong 已提交
83 84 85
    /** centroids (k * d)
     * if centroids are set on input to train, they will be used as initialization
     */
J
JinHai-CN 已提交
86 87
    std::vector<float> centroids;

C
Cai Yudong 已提交
88 89
    /// stats at every iteration of clustering
    std::vector<ClusteringIterationStats> iteration_stats;
J
JinHai-CN 已提交
90 91 92 93

    Clustering (int d, int k);
    Clustering (int d, int k, const ClusteringParameters &cp);

C
Cai Yudong 已提交
94 95 96 97 98 99 100 101 102
    /** run k-means training
     *
     * @param x          training vectors, size n * d
     * @param index      index used for assignment
     * @param x_weights  weight associated to each vector: NULL or size n
     */
    virtual void train (idx_t n, const float * x, faiss::Index & index,
                        const float *x_weights = nullptr);

T
TheBloodthirster 已提交
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
    /**
     * @brief Kmeans algorithm
     * 
     * @param centroids_index   [out] centroids index
     * @param random_seed       seed for the random number generator
     * @param n_input_centroids the number of centroids that user input
     * @param d                 dimension
     * @param k                 number of centroids
     * @param nx                size of data
     * @param x_in              data of point
     */
    void kmeans_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
                          size_t n_input_centroids, size_t d, size_t k,
                          idx_t nx, const uint8_t *x_in);

    void kmeans_plus_plus_algorithm(std::vector<int>& centroids_index, int64_t random_seed,
                                    size_t n_input_centroids, size_t d, size_t k,
                                    idx_t nx, const uint8_t *x_in);
C
Cai Yudong 已提交
121 122 123 124 125 126 127 128 129 130 131 132

    /** run with encoded vectors
     *
     * win addition to train()'s parameters takes a codec as parameter
     * to decode the input vectors.
     *
     * @param codec      codec used to decode the vectors (nullptr =
     *                   vectors are in fact floats)     *
     */
    void train_encoded (idx_t nx, const uint8_t *x_in,
                        const Index * codec, Index & index,
                        const float *weights = nullptr);
J
JinHai-CN 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160

    /// Post-process the centroids after each centroid update.
    /// includes optional L2 normalization and nearest integer rounding
    void post_process_centroids ();

    virtual ~Clustering() {}
};


/**
 * Simplified, one-call k-means interface.
 *
 * @param d          dimension of the data
 * @param n          nb of training vectors
 * @param k          nb of output centroids
 * @param x          training set, size n * d
 * @param centroids  output centroids, size k * d
 * @return           final quantization error
 */
float kmeans_clustering(
        size_t d,
        size_t n,
        size_t k,
        const float* x,
        float* centroids);



}


#endif