提交 1f4dc61d 编写于 作者: M Marius Muja

Updates

上级 90c5a42d
......@@ -221,7 +221,7 @@ public:
/**
Method that searches for nearest-neighbors
*/
virtual void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
virtual void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
if (searchParams.checks==-2) {
bestIndex->findNeighbors(result, vec, bestSearchParams);
......
......@@ -171,7 +171,7 @@ public:
kdtree->loadIndex(stream);
}
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
kmeans->findNeighbors(result,vec,searchParams);
kdtree->findNeighbors(result,vec,searchParams);
......
......@@ -298,7 +298,7 @@ public:
* vec = the vector for which to search the nearest neighbors
* maxCheck = the maximum number of restarts (in a best-bin-first manner)
*/
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
int maxChecks = searchParams.checks;
float epsError = 1+searchParams.eps;
......@@ -496,7 +496,7 @@ private:
* Performs an exact nearest neighbor search. The exact search performs a full
* traversal of the tree.
*/
void getExactNeighbors(ResultSet& result, const ElementType* vec, float epsError)
void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
{
// checkID -= 1; /* Set a different unique ID for each search. */
......@@ -514,7 +514,7 @@ private:
* because the tree traversal is abandoned after a given number of descends in
* the tree.
*/
void getNeighbors(ResultSet& result, const ElementType* vec, int maxCheck, float epsError)
void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
{
int i;
BranchSt branch;
......@@ -544,7 +544,7 @@ private:
* higher levels, all exemplars below this level must have a distance of
* at least "mindistsq".
*/
void searchLevel(ResultSet& result_set, const ElementType* vec, NodePtr node, float mindistsq, int& checkCount, int maxCheck,
void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, float mindistsq, int& checkCount, int maxCheck,
float epsError, Heap<BranchSt>* heap, vector<bool>& checked)
{
if (result_set.worstDist()<mindistsq) {
......@@ -600,7 +600,7 @@ private:
/**
* Performs an exact search in the tree starting from a node.
*/
void searchLevelExact(ResultSet& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
{
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
......
/***********************************************************************
* Software License Agreement (BSD License)
*
* Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
* Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
*
* THE BSD LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#ifndef KDTREE_H
#define KDTREE_H
#include <algorithm>
#include <map>
#include <cassert>
#include <cstring>
#include "flann/general.h"
#include "flann/algorithms/nn_index.h"
#include "flann/util/matrix.h"
#include "flann/util/result_set.h"
#include "flann/util/heap.h"
#include "flann/util/allocator.h"
#include "flann/util/random.h"
#include "flann/util/saving.h"
using namespace std;
namespace flann
{
struct KDTreeIndexParams : public IndexParams {
KDTreeIndexParams(int trees_ = 4) : IndexParams(KDTREE), trees(trees_) {};
int trees; // number of randomized trees to use (for kdtree)
flann_algorithm_t getIndexType() const { return algorithm; }
void fromParameters(const FLANNParameters& p)
{
assert(p.algorithm==algorithm);
trees = p.trees;
}
void toParameters(FLANNParameters& p) const
{
p.algorithm = algorithm;
p.trees = trees;
}
void print() const
{
logger.info("Index type: %d\n",(int)algorithm);
logger.info("Trees: %d\n", trees);
}
};
/**
* Randomized kd-tree index
*
* Contains the k-d trees and other information for indexing a set of points
* for nearest-neighbor matching.
*/
template <typename Distance>
class KDTreeIndex : public NNIndex<Distance>
{
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
enum {
/**
* To improve efficiency, only SAMPLE_MEAN random values are used to
* compute the mean and variance at each level when building a tree.
* A value of 100 seems to perform as well as using all values.
*/
SAMPLE_MEAN = 100,
/**
* Top random dimensions to consider
*
* When creating random trees, the dimension on which to subdivide is
* selected at random from among the top RAND_DIM dimensions with the
* highest variance. A value of 5 works well.
*/
RAND_DIM=5
};
/**
* Number of randomized trees that are used
*/
int numTrees;
/**
* Array of indices to vectors in the dataset.
*/
int* vind;
/**
* The dataset used by this index
*/
const Matrix<ElementType> dataset;
const IndexParams& index_params;
size_t size_;
size_t veclen_;
DistanceType* mean;
DistanceType* var;
/*--------------------- Internal Data Structures --------------------------*/
/**
* A node of the binary k-d tree.
*
* This is All nodes that have vec[divfeat] < divval are placed in the
* child1 subtree, else child2., A leaf node is indicated if both children are NULL.
*/
struct TreeSt {
/**
* Index of the vector feature used for subdivision.
* If this is a leaf node (both children are NULL) then
* this holds vector index for this leaf.
*/
int divfeat;
/**
* The value used for subdivision.
*/
DistanceType divval;
/**
* The child nodes.
*/
TreeSt *child1, *child2;
};
typedef TreeSt* Tree;
/**
* Array of k-d trees used to find neighbours.
*/
Tree* trees;
typedef BranchStruct<Tree> BranchSt;
typedef BranchSt* Branch;
/**
* Pooled memory allocator.
*
* Using a pooled memory allocator is more efficient
* than allocating memory directly when there is a large
* number small of memory allocations.
*/
PooledAllocator pool;
Distance distance;
public:
flann_algorithm_t getType() const
{
return KDTREE;
}
/**
* KDTree constructor
*
* Params:
* inputData = dataset with the input features
* params = parameters passed to the kdtree algorithm
*/
KDTreeIndex(const Matrix<ElementType>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams(),
Distance d = Distance()) :
dataset(inputData), index_params(params), distance(d)
{
size_ = dataset.rows;
veclen_ = dataset.cols;
numTrees = params.trees;
trees = new Tree[numTrees];
// get the parameters
// if (params.find("trees") != params.end()) {
// numTrees = (int)params["trees"];
// trees = new Tree[numTrees];
// }
// else {
// numTrees = -1;
// trees = NULL;
// }
// Create a permutable array of indices to the input vectors.
vind = new int[size_];
for (size_t i = 0; i < size_; i++) {
vind[i] = i;
}
mean = new DistanceType[veclen_];
var = new DistanceType[veclen_];
}
/**
* Standard destructor
*/
~KDTreeIndex()
{
delete[] vind;
if (trees!=NULL) {
delete[] trees;
}
delete[] mean;
delete[] var;
}
/**
* Builds the index
*/
void buildIndex()
{
/* Construct the randomized trees. */
for (int i = 0; i < numTrees; i++) {
/* Randomize the order of vectors to allow for unbiased sampling. */
for (int j = size_; j > 0; --j) {
int rnd = rand_int(j);
swap(vind[j-1], vind[rnd]);
}
trees[i] = divideTree(0, size_ - 1);
}
}
void saveIndex(FILE* stream)
{
save_value(stream, numTrees);
for (int i=0;i<numTrees;++i) {
save_tree(stream, trees[i]);
}
}
void loadIndex(FILE* stream)
{
load_value(stream, numTrees);
if (trees!=NULL) {
delete[] trees;
}
trees = new Tree[numTrees];
for (int i=0;i<numTrees;++i) {
load_tree(stream,trees[i]);
}
}
/**
* Returns size of index.
*/
size_t size() const
{
return size_;
}
/**
* Returns the length of an index feature.
*/
size_t veclen() const
{
return veclen_;
}
/**
* Computes the inde memory usage
* Returns: memory used by the index
*/
int usedMemory() const
{
return pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int); // pool memory and vind array memory
}
/**
* Find set of nearest neighbors to vec. Their indices are stored inside
* the result object.
*
* Params:
* result = the result object in which the indices of the nearest-neighbors are stored
* vec = the vector for which to search the nearest neighbors
* maxCheck = the maximum number of restarts (in a best-bin-first manner)
*/
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
{
int maxChecks = searchParams.checks;
if (maxChecks<0) {
getExactNeighbors(result, vec);
} else {
getNeighbors(result, vec, maxChecks);
}
}
const IndexParams* getParameters() const
{
return &index_params;
}
private:
void save_tree(FILE* stream, Tree tree)
{
save_value(stream, *tree);
if (tree->child1!=NULL) {
save_tree(stream, tree->child1);
}
if (tree->child2!=NULL) {
save_tree(stream, tree->child2);
}
}
void load_tree(FILE* stream, Tree& tree)
{
tree = pool.allocate<TreeSt>();
load_value(stream, *tree);
if (tree->child1!=NULL) {
load_tree(stream, tree->child1);
}
if (tree->child2!=NULL) {
load_tree(stream, tree->child2);
}
}
/**
* Create a tree node that subdivides the list of vecs from vind[first]
* to vind[last]. The routine is called recursively on each sublist.
* Place a pointer to this new tree node in the location pTree.
*
* Params: pTree = the new node to create
* first = index of the first vector
* last = index of the last vector
*/
Tree divideTree(int first, int last)
{
Tree node = pool.allocate<TreeSt>(); // allocate memory
/* If only one exemplar remains, then make this a leaf node. */
if (first == last) {
node->child1 = node->child2 = NULL; /* Mark as leaf node. */
node->divfeat = vind[first]; /* Store index of this vec. */
}
else {
chooseDivision(node, first, last);
subdivide(node, first, last);
}
return node;
}
/**
* Choose which feature to use in order to subdivide this set of vectors.
* Make a random choice among those with the highest variance, and use
* its variance as the threshold value.
*/
void chooseDivision(Tree node, int first, int last)
{
memset(mean,0,veclen_*sizeof(DistanceType));
memset(var,0,veclen_*sizeof(DistanceType));
/* Compute mean values. Only the first SAMPLE_MEAN values need to be
sampled to get a good estimate.
*/
int end = min(first + SAMPLE_MEAN, last);
for (int j = first; j <= end; ++j) {
ElementType* v = dataset[vind[j]];
for (size_t k=0; k<veclen_; ++k) {
mean[k] += v[k];
}
}
for (size_t k=0; k<veclen_; ++k) {
mean[k] /= (end - first + 1);
}
/* Compute variances (no need to divide by count). */
for (int j = first; j <= end; ++j) {
ElementType* v = dataset[vind[j]];
for (size_t k=0; k<veclen_; ++k) {
DistanceType dist = v[k] - mean[k];
var[k] += dist * dist;
}
}
/* Select one of the highest variance indices at random. */
node->divfeat = selectDivision(var);
node->divval = mean[node->divfeat];
}
/**
* Select the top RAND_DIM largest values from v and return the index of
* one of these selected at random.
*/
int selectDivision(DistanceType* v)
{
int num = 0;
int topind[RAND_DIM];
/* Create a list of the indices of the top RAND_DIM values. */
for (size_t i = 0; i < veclen_; ++i) {
if (num < RAND_DIM || v[i] > v[topind[num-1]]) {
/* Put this element at end of topind. */
if (num < RAND_DIM) {
topind[num++] = i; /* Add to list. */
}
else {
topind[num-1] = i; /* Replace last element. */
}
/* Bubble end value down to right location by repeated swapping. */
int j = num - 1;
while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
swap(topind[j], topind[j-1]);
--j;
}
}
}
/* Select a random integer in range [0,num-1], and return that index. */
int rnd = rand_int(num);
return topind[rnd];
}
/**
* Subdivide the list of exemplars using the feature and division
* value given in this node. Call divideTree recursively on each list.
*/
void subdivide(Tree node, int first, int last)
{
/* Move vector indices for left subtree to front of list. */
int i = first;
int j = last;
while (i <= j) {
int ind = vind[i];
ElementType val = dataset[ind][node->divfeat];
if (val < node->divval) {
++i;
} else {
/* Move to end of list by swapping vind i and j. */
swap(vind[i], vind[j]);
--j;
}
}
/* If either list is empty, it means we have hit the unlikely case
in which all remaining features are identical. Split in the middle
to maintain a balanced tree.
*/
if ( (i == first) || (i == last+1)) {
i = (first+last+1)/2;
}
node->child1 = divideTree(first, i - 1);
node->child2 = divideTree(i, last);
}
/**
* Performs an exact nearest neighbor search. The exact search performs a full
* traversal of the tree.
*/
void getExactNeighbors(ResultSet& result, const ElementType* vec)
{
// checkID -= 1; /* Set a different unique ID for each search. */
if (numTrees > 1) {
fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
}
if (numTrees>0) {
searchLevelExact(result, vec, trees[0], 0.0);
}
assert(result.full());
}
/**
* Performs the approximate nearest-neighbor search. The search is approximate
* because the tree traversal is abandoned after a given number of descends in
* the tree.
*/
void getNeighbors(ResultSet& result, const ElementType* vec, int maxCheck)
{
int i;
BranchSt branch;
int checkCount = 0;
Heap<BranchSt>* heap = new Heap<BranchSt>(size_);
vector<bool> checked(size_,false);
/* Search once through each tree down to root. */
for (i = 0; i < numTrees; ++i) {
searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked);
}
/* Keep searching other branches from heap until finished. */
while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked);
}
delete heap;
assert(result.full());
}
/**
* Search starting from a given node of the tree. Based on any mismatches at
* higher levels, all exemplars below this level must have a distance of
* at least "mindistsq".
*/
void searchLevel(ResultSet& result, const ElementType* vec, Tree node, float mindistsq, int& checkCount, int maxCheck,
Heap<BranchSt>* heap, vector<bool>& checked)
{
if (result.worstDist()<mindistsq) {
// printf("Ignoring branch, too far\n");
return;
}
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
/* Do not check same node more than once when searching multiple trees.
Once a vector is checked, we set its location in vind to the
current checkID.
*/
if (checked[node->divfeat] == true || checkCount>=maxCheck) {
if (result.full()) return;
}
checkCount++;
checked[node->divfeat] = true;
int index = node->divfeat;
DistanceType dist = distance(dataset[index], vec, veclen_);
result.addPoint(dist,index);
return;
}
/* Which child branch should be taken first? */
ElementType val = vec[node->divfeat];
DistanceType diff = val - node->divval;
Tree bestChild = (diff < 0) ? node->child1 : node->child2;
Tree otherChild = (diff < 0) ? node->child2 : node->child1;
/* Create a branch record for the branch not taken. Add distance
of this feature boundary (we don't attempt to correct for any
use of this feature in a parent node, which is unlikely to
happen and would have only a small effect). Don't bother
adding more branches to heap after halfway point, as cost of
adding exceeds their value.
*/
DistanceType new_distsq = mindistsq + distance.accum_dist(val, node->divval);
// if (2 * checkCount < maxCheck || !result.full()) {
if (new_distsq < result.worstDist() || !result.full()) {
heap->insert( BranchSt(otherChild, new_distsq) );
}
/* Call recursively to search next level down. */
searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked);
}
/**
* Performs an exact search in the tree starting from a node.
*/
void searchLevelExact(ResultSet& result, const ElementType* vec, Tree node, float mindistsq)
{
if (mindistsq>result.worstDist()) {
return;
}
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
/* Do not check same node more than once when searching multiple trees.
Once a vector is checked, we set its location in vind to the
current checkID.
*/
// if (vind[node->divfeat] == checkID)
// return;
// vind[node->divfeat] = checkID;
int index = node->divfeat;
DistanceType dist = distance(dataset[index], vec, veclen_);
result.addPoint(dist, index);
return;
}
/* Which child branch should be taken first? */
ElementType val = vec[node->divfeat];
DistanceType diff = val - node->divval;
Tree bestChild = (diff < 0) ? node->child1 : node->child2;
Tree otherChild = (diff < 0) ? node->child2 : node->child1;
/* Call recursively to search next level down. */
searchLevelExact(result, vec, bestChild, mindistsq);
DistanceType new_distsq = mindistsq + distance.accum_dist(val, node->divval);
searchLevelExact(result, vec, otherChild, new_distsq);
}
}; // class KDTree
}
#endif //KDTREE_H
......@@ -297,7 +297,7 @@ public:
* vec = the vector for which to search the nearest neighbors
* maxCheck = the maximum number of restarts (in a best-bin-first manner)
*/
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
// int maxChecks = searchParams.checks;
float epsError = 1+searchParams.eps;
......@@ -493,7 +493,7 @@ private:
/**
* Performs an exact search in the tree starting from a node.
*/
void searchLevel(ResultSet& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
{
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
......
......@@ -560,7 +560,7 @@ public:
* vec = the vector for which to search the nearest neighbors
* searchParams = parameters that influence the search algorithm (checks, cb_index)
*/
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
int maxChecks = searchParams.checks;
......@@ -930,7 +930,7 @@ private:
*/
void findNN(KMeansNodePtr node, ResultSet& result, const ElementType* vec, int& checks, int maxChecks,
void findNN(KMeansNodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
Heap<BranchSt>* heap)
{
// Ignore those clusters that are too far away
......@@ -1010,7 +1010,7 @@ private:
/**
* Function the performs exact nearest neighbor search by traversing the entire tree.
*/
void findExactNN(KMeansNodePtr node, ResultSet& result, const ElementType* vec)
void findExactNN(KMeansNodePtr node, ResultSet<DistanceType>& result, const ElementType* vec)
{
// Ignore those clusters that are too far away
{
......
......@@ -115,7 +115,7 @@ public:
/* nothing to do here for linear search */
}
void findNeighbors(ResultSet& resultSet, const ElementType* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet<DistanceType>& resultSet, const ElementType* vec, const SearchParams& searchParams)
{
for (size_t i=0;i<dataset.rows;++i) {
DistanceType dist = distance(dataset[i],vec, dataset.cols);
......
......@@ -41,6 +41,7 @@ using namespace std;
namespace flann
{
template <typename DistanceType>
class ResultSet;
/**
......@@ -50,6 +51,7 @@ template <typename Distance>
class NNIndex
{
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
public:
......@@ -73,7 +75,7 @@ public:
/**
Method that searches for nearest-neighbors
*/
virtual void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams) = 0;
virtual void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) = 0;
/**
Number of features in this index.
......
......@@ -75,13 +75,15 @@ float search_with_ground_truth(NNIndex<Distance>& index, const Matrix<typename D
const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches, int nn, int checks,
float& time, float& dist, const Distance& distance, int skipMatches)
{
typedef typename Distance::ResultType DistanceType;
if (matches.cols<size_t(nn)) {
logger.info("matches.cols=%d, nn=%d\n",matches.cols,nn);
throw FLANNException("Ground truth is not computed for as many neighbors as requested");
}
KNNResultSet resultSet(nn+skipMatches);
KNNResultSet<DistanceType> resultSet(nn+skipMatches);
SearchParams searchParams(checks);
int correct;
......
......@@ -62,6 +62,7 @@ struct BranchStruct {
};
template <typename DistanceType>
class ResultSet
{
public:
......@@ -71,22 +72,23 @@ public:
virtual int* getNeighbors() = 0;
virtual float* getDistances() = 0;
virtual DistanceType* getDistances() = 0;
virtual size_t size() const = 0;
virtual bool full() const = 0;
virtual void addPoint(float dist, int index) = 0;
virtual void addPoint(DistanceType dist, int index) = 0;
virtual float worstDist() const = 0;
virtual DistanceType worstDist() const = 0;
};
class KNNResultSet : public ResultSet
template <typename DistanceType>
class KNNResultSet : public ResultSet<DistanceType>
{
int* indices;
float* dists;
DistanceType* dists;
int capacity;
int count;
......@@ -94,7 +96,7 @@ public:
KNNResultSet(int capacity_) : capacity(capacity_)
{
indices = new int[capacity_+1];
dists = new float[capacity_+1];
dists = new DistanceType[capacity_+1];
count = 0;
}
......@@ -107,7 +109,7 @@ public:
void init()
{
count = 0;
dists[capacity-1] = numeric_limits<float>::max();
dists[capacity-1] = (numeric_limits<DistanceType>::max) ();
}
int* getNeighbors()
......@@ -115,7 +117,7 @@ public:
return indices;
}
float* getDistances()
DistanceType* getDistances()
{
return dists;
}
......@@ -131,7 +133,7 @@ public:
}
void addPoint(float dist, int index)
void addPoint(DistanceType dist, int index)
{
// for (int i=0;i<count;++i) {
// if (indices[i]==index) return false;
......@@ -170,7 +172,7 @@ public:
// }
}
float worstDist() const
DistanceType worstDist() const
{
return dists[capacity-1];
}
......@@ -180,11 +182,12 @@ public:
/**
* A result-set class used when performing a radius based search.
*/
class RadiusResultSet : public ResultSet
template <typename DistanceType>
class RadiusResultSet : public ResultSet<DistanceType>
{
struct Item {
int index;
float dist;
DistanceType dist;
bool operator<(Item rhs) {
return dist<rhs.dist;
......@@ -192,11 +195,11 @@ class RadiusResultSet : public ResultSet
};
vector<Item> items;
float radius;
DistanceType radius;
bool sorted;
int* indices;
float* dists;
DistanceType* dists;
size_t count;
private:
......@@ -207,12 +210,12 @@ private:
if (dists!=NULL) delete[] dists;
count = items.size();
indices = new int[count];
dists = new float[count];
dists = new DistanceType[count];
}
}
public:
RadiusResultSet(float radius_) :
RadiusResultSet(DistanceType radius_) :
radius(radius_), indices(NULL), dists(NULL)
{
sorted = false;
......@@ -246,7 +249,7 @@ public:
return indices;
}
float* getDistances()
DistanceType* getDistances()
{
if (!sorted) {
sorted = true;
......@@ -270,7 +273,7 @@ public:
return true;
}
void addPoint(float dist, int index)
void addPoint(DistanceType dist, int index)
{
count++;
// Item it;
......@@ -282,7 +285,7 @@ public:
// }
}
float worstDist() const
DistanceType worstDist() const
{
return radius;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册