提交 37c9902d 编写于 作者: M Marius Muja

Preparing for release 1.6.

Added single kd-tree index that works better for low dimensionality
points, such as 3D point clouds.

Refactored entire library so that the indexes are templated on the
distance.
上级 ac168d2e
......@@ -4,14 +4,19 @@ if(COMMAND cmake_policy)
cmake_policy(SET CMP0003 NEW)
endif(COMMAND cmake_policy)
project(flann)
string(TOLOWER ${PROJECT_NAME} PROJECT_NAME_LOWER)
include(${PROJECT_SOURCE_DIR}/cmake/flann_utils.cmake)
set(FLANN_VERSION 1.5.0)
STRING(REGEX MATCHALL "[0-9]" FLANN_VERSION_PARTS "${FLANN_VERSION}")
LIST(GET FLANN_VERSION_PARTS 0 FLANN_VERSION_MAJOR)
LIST(GET FLANN_VERSION_PARTS 1 FLANN_VERSION_MINOR)
LIST(GET FLANN_VERSION_PARTS 2 FLANN_VERSION_PATCH)
DISSECT_VERSION()
GET_OS_INFO()
# Add an "uninstall" target
CONFIGURE_FILE ("${PROJECT_SOURCE_DIR}/cmake/uninstall_target.cmake.in"
"${PROJECT_BINARY_DIR}/uninstall_target.cmake" IMMEDIATE @ONLY)
ADD_CUSTOM_TARGET (uninstall "${CMAKE_COMMAND}" -P
"${PROJECT_BINARY_DIR}/uninstall_target.cmake")
if (NOT CMAKE_INSTALL_PREFIX)
set(CMAKE_INSTALL_PREFIX /usr/local)
......@@ -43,27 +48,6 @@ set(TEST_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/test)
message("USE_MPI: ${USE_MPI}")
# workaround a FindHDF5 bug
# Locate HDF5 and determine whether it was built with parallel (MPI) IO.
# Works around a FindHDF5 bug by probing H5pubconf.h directly for the
# HAVE_PARALLEL define instead of trusting the module's own detection.
# Kept as a macro so find_package(HDF5) result variables land in caller scope.
macro(find_hdf5)
    find_package(HDF5)

    set(HDF5_IS_PARALLEL FALSE)
    foreach(_dir ${HDF5_INCLUDE_DIRS})
        if(EXISTS "${_dir}/H5pubconf.h")
            # Any line matching "HAVE_PARALLEL 1" means a parallel HDF5 build.
            file(STRINGS "${_dir}/H5pubconf.h"
                HDF5_HAVE_PARALLEL_DEFINE
                REGEX "HAVE_PARALLEL 1")
            if(HDF5_HAVE_PARALLEL_DEFINE)
                set(HDF5_IS_PARALLEL TRUE)
            endif()
        endif()
    endforeach()

    # Publish the result as an advanced cache entry so it can be inspected
    # or overridden from the cache.
    set(HDF5_IS_PARALLEL ${HDF5_IS_PARALLEL} CACHE BOOL
        "HDF5 library compiled with parallel IO support")
    mark_as_advanced(HDF5_IS_PARALLEL)
endmacro()
find_hdf5()
if (USE_MPI OR HDF5_IS_PARALLEL)
find_package(MPI)
......@@ -94,82 +78,10 @@ endif(USE_MPI)
#set the C/C++ include path to the "include" directory
include_directories(${PROJECT_SOURCE_DIR}/src/cpp)
add_custom_target(tests)
add_custom_target(test)
add_dependencies(test tests)
# flann_add_gtest(<exe> <sources...>)
# Registers a Google Test executable:
#  - builds <exe> from the remaining arguments (excluded from the default build)
#  - links it against gtest and hooks it into the aggregate 'tests' target
#  - adds a 'test_<exe>' target that runs the binary from the test data dir
#  - makes the umbrella 'test' target run it
macro(flann_add_gtest exe)
    # add build target
    add_executable(${exe} EXCLUDE_FROM_ALL ${ARGN})
    target_link_libraries(${exe} gtest)

    # building the aggregate 'tests' target builds this executable too
    add_dependencies(tests ${exe})

    # add target for running the test ('/' is not valid in target names)
    string(REPLACE "/" "_" _testname ${exe})
    add_custom_target(test_${_testname}
        COMMAND ${exe}
        ARGS --gtest_print_time
        DEPENDS ${exe}
        WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/test
        VERBATIM
        COMMENT "Running gtest test(s) ${exe}")

    # add dependency to 'test' target
    add_dependencies(test test_${_testname})
endmacro()
# flann_add_pyunit(<file>)
# Registers a Python unit-test file so it is executed as part of 'make test'
# through bin/run_test.py, after the flann library has been built.
macro(flann_add_pyunit file)
    # locate the test file; reset first so a stale cached hit is not reused
    set(_file_name _file_name-NOTFOUND)
    find_file(_file_name ${file} ${CMAKE_CURRENT_SOURCE_DIR})
    if(NOT _file_name)
        message(FATAL_ERROR "Can't find pyunit file \"${file}\"")
    endif()

    # a Python interpreter is needed to drive the test
    find_package(PythonInterp REQUIRED)

    # add target for running the test ('/' is not valid in target names)
    string(REPLACE "/" "_" _testname ${file})
    add_custom_target(pyunit_${_testname}
        COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/bin/run_test.py ${_file_name}
        DEPENDS ${_file_name}
        WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/test
        VERBATIM
        COMMENT "Running pyunit test(s) ${file}")

    # the library must exist before the test runs; hook into 'test'
    add_dependencies(pyunit_${_testname} flann)
    add_dependencies(test pyunit_${_testname})
endmacro()
# flann_download_test_data(<name> <md5>)
# Adds a 'dataset_<name>' target that downloads the named dataset into
# ${TEST_OUTPUT_PATH} (verifying its md5 checksum) and makes the aggregate
# 'tests' target depend on it, so data is present before any test runs.
macro(flann_download_test_data _name _md5)
    string(REPLACE "/" "_" _dataset_name dataset_${_name})

    # the download helper is a Python script, so an interpreter is required
    find_package(PythonInterp REQUIRED)

    add_custom_target(${_dataset_name}
        # run the script through the found interpreter; invoking it directly
        # relied on the execute bit / shebang and fails on Windows
        COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/bin/download_checkmd5.py http://people.cs.ubc.ca/~mariusm/uploads/FLANN/datasets/${_name} ${TEST_OUTPUT_PATH}/${_name} ${_md5}
        VERBATIM)

    # Also make sure that downloads are done before we run any tests
    add_dependencies(tests ${_dataset_name})
endmacro()
# require proper c++
add_definitions( "-Wall -ansi -pedantic" )
add_subdirectory( cmake )
add_subdirectory( src )
add_subdirectory( examples )
add_subdirectory( test )
......
......@@ -4,20 +4,22 @@
#include <stdio.h>
using namespace flann;
int main(int argc, char** argv)
{
int nn = 3;
flann::Matrix<float> dataset;
flann::Matrix<float> query;
flann::load_from_file(dataset, "dataset.hdf5","dataset");
flann::load_from_file(query, "dataset.hdf5","query");
Matrix<float> dataset;
Matrix<float> query;
load_from_file(dataset, "dataset.hdf5","dataset");
load_from_file(query, "dataset.hdf5","query");
flann::Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
flann::Matrix<float> dists(new float[query.rows*nn], query.rows, nn);
Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
Matrix<float> dists(new float[query.rows*nn], query.rows, nn);
// construct an randomized kd-tree index using 4 kd-trees
flann::Index<float> index(dataset, flann::KDTreeIndexParams(4));
Index<L2<float> > index(dataset, flann::KDTreeIndexParams(4));
index.buildIndex();
// do a knn search, using 128 checks
......
#include_directories(${CMAKE_SOURCE_DIR}/include algorithms util nn .)
SET(FLANN_SOVERSION "${FLANN_VERSION_MAJOR}.${FLANN_VERSION_MINOR}")
file(GLOB_RECURSE C_SOURCES *.cpp)
file(GLOB_RECURSE CPP_SOURCES *.cpp)
......@@ -13,15 +11,15 @@ list(REMOVE_ITEM C_SOURCES ${CPP_FLANN})
add_library(flann SHARED ${C_SOURCES})
set_target_properties(flann PROPERTIES
VERSION ${FLANN_VERSION}
SOVERSION ${FLANN_SOVERSION}
set_target_properties(flann PROPERTIES
VERSION ${FLANN_VERSION}
SOVERSION ${FLANN_SOVERSION}
)
add_library(flann_cpp SHARED ${CPP_SOURCES})
set_target_properties(flann_cpp PROPERTIES
VERSION ${FLANN_VERSION}
SOVERSION ${FLANN_SOVERSION}
)
set_target_properties(flann_cpp PROPERTIES
VERSION ${FLANN_VERSION}
SOVERSION ${FLANN_SOVERSION}
)
add_library(flann_s STATIC ${C_SOURCES})
add_library(flann_cpp_s STATIC ${CPP_SOURCES})
......@@ -29,15 +27,15 @@ add_library(flann_cpp_s STATIC ${CPP_SOURCES})
if(WIN32)
install (
TARGETS flann
RUNTIME DESTINATION matlab
RUNTIME DESTINATION matlab
)
endif(WIN32)
install (
TARGETS flann flann_cpp flann_s flann_cpp_s
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib
LIBRARY DESTINATION ${FLANN_LIB_INSTALL_DIR}
ARCHIVE DESTINATION ${FLANN_LIB_INSTALL_DIR}
)
install (
......
......@@ -34,7 +34,7 @@
#include "flann/algorithms/nn_index.h"
#include "flann/algorithms/kdtree_index.h"
#include "flann/algorithms/kdtree_index2.h"
#include "flann/algorithms/kdtree_simple_index.h"
#include "flann/algorithms/kmeans_index.h"
#include "flann/algorithms/composite_index.h"
#include "flann/algorithms/linear_index.h"
......@@ -43,31 +43,31 @@
namespace flann {
template<typename T>
NNIndex<T>* create_index_by_type(const Matrix<T>& dataset, const IndexParams& params)
template<typename Distance>
NNIndex<Distance>* create_index_by_type(const Matrix<typename Distance::ElementType>& dataset, const IndexParams& params, const Distance& distance)
{
flann_algorithm_t index_type = params.getIndexType();
NNIndex<T>* nnIndex;
NNIndex<Distance>* nnIndex;
switch (index_type) {
case LINEAR:
nnIndex = new LinearIndex<T>(dataset, (const LinearIndexParams&)params);
nnIndex = new LinearIndex<Distance>(dataset, (const LinearIndexParams&)params, distance);
break;
case KDTREE2:
nnIndex = new KDTreeIndex2<T>(dataset, (const KDTreeIndex2Params&)params);
case KDTREE_SIMPLE:
nnIndex = new KDTreeSimpleIndex<Distance>(dataset, (const KDTreeSimpleIndexParams&)params, distance);
break;
case KDTREE:
nnIndex = new KDTreeIndex<T>(dataset, (const KDTreeIndexParams&)params);
nnIndex = new KDTreeIndex<Distance>(dataset, (const KDTreeIndexParams&)params, distance);
break;
case KMEANS:
nnIndex = new KMeansIndex<T>(dataset, (const KMeansIndexParams&)params);
nnIndex = new KMeansIndex<Distance>(dataset, (const KMeansIndexParams&)params, distance);
break;
case COMPOSITE:
nnIndex = new CompositeIndex<T>(dataset, (const CompositeIndexParams&) params);
break;
case AUTOTUNED:
nnIndex = new AutotunedIndex<T>(dataset, (const AutotunedIndexParams&) params);
nnIndex = new CompositeIndex<Distance>(dataset, (const CompositeIndexParams&) params, distance);
break;
// case AUTOTUNED:
// nnIndex = new AutotunedIndex<Distance>(dataset, (const AutotunedIndexParams&) params, distance);
// break;
default:
throw FLANNException("Unknown index type");
}
......
......@@ -27,7 +27,6 @@
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#ifndef AUTOTUNEDINDEX_H_
#define AUTOTUNEDINDEX_H_
......@@ -41,6 +40,37 @@
namespace flann
{
// Factory: build the NNIndex implementation selected by params.getIndexType()
// for the given dataset and distance functor.
// The returned index is heap-allocated and buildIndex() has NOT been called
// on it yet; the caller owns the pointer and must delete it.
// Throws FLANNException when the algorithm value is not one of the handled
// cases (note: AUTOTUNED is intentionally not handled here, since this helper
// is used from within AutotunedIndex itself).
template<typename Distance>
NNIndex<Distance>* index_by_type(const Matrix<typename Distance::ElementType>& dataset, const IndexParams& params, const Distance& distance)
{
flann_algorithm_t index_type = params.getIndexType();

NNIndex<Distance>* nnIndex;
switch (index_type) {
case LINEAR:
nnIndex = new LinearIndex<Distance>(dataset, (const LinearIndexParams&)params, distance);
break;
case KDTREE_SIMPLE:
nnIndex = new KDTreeSimpleIndex<Distance>(dataset, (const KDTreeSimpleIndexParams&)params, distance);
break;
case KDTREE:
nnIndex = new KDTreeIndex<Distance>(dataset, (const KDTreeIndexParams&)params, distance);
break;
case KMEANS:
nnIndex = new KMeansIndex<Distance>(dataset, (const KMeansIndexParams&)params, distance);
break;
case COMPOSITE:
nnIndex = new CompositeIndex<Distance>(dataset, (const CompositeIndexParams&) params, distance);
break;
default:
throw FLANNException("Unknown index type");
}

return nnIndex;
}
struct AutotunedIndexParams : public IndexParams {
AutotunedIndexParams( float target_precision_ = 0.8, float build_weight_ = 0.01,
float memory_weight_ = 0, float sample_fraction_ = 0.1) :
......@@ -55,8 +85,6 @@ struct AutotunedIndexParams : public IndexParams {
float memory_weight; // index memory weighting factor
float sample_fraction; // what fraction of the dataset to use for autotuning
flann_algorithm_t getIndexType() const { return algorithm; }
void fromParameters(const FLANNParameters& p)
{
assert(p.algorithm==algorithm);
......@@ -86,16 +114,19 @@ struct AutotunedIndexParams : public IndexParams {
};
template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
class AutotunedIndex : public NNIndex<ELEM_TYPE>
template <typename Distance>
class AutotunedIndex : public NNIndex<Distance>
{
NNIndex<ELEM_TYPE>* bestIndex;
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
NNIndex<Distance>* bestIndex;
IndexParams* bestParams;
SearchParams bestSearchParams;
Matrix<ELEM_TYPE> sampledDataset;
Matrix<ELEM_TYPE> testDataset;
Matrix<ElementType> sampledDataset;
Matrix<ElementType> testDataset;
Matrix<int> gt_matches;
float speedup;
......@@ -103,17 +134,19 @@ class AutotunedIndex : public NNIndex<ELEM_TYPE>
/**
* The dataset used by this index
*/
const Matrix<ELEM_TYPE> dataset;
const Matrix<ElementType> dataset;
/**
* Index parameters
*/
const AutotunedIndexParams& index_params;
Distance distance;
public:
AutotunedIndex(const Matrix<ELEM_TYPE>& inputData, const AutotunedIndexParams& params = AutotunedIndexParams() ) :
dataset(inputData), index_params(params)
AutotunedIndex(const Matrix<ElementType>& inputData, const AutotunedIndexParams& params = AutotunedIndexParams(),
Distance d = Distance()) :
dataset(inputData), index_params(params), distance(d)
{
bestIndex = NULL;
bestParams = NULL;
......@@ -142,16 +175,16 @@ public:
flann_algorithm_t index_type = bestParams->getIndexType();
switch (index_type) {
case LINEAR:
bestIndex = new LinearIndex<ELEM_TYPE>(dataset, (const LinearIndexParams&)*bestParams);
bestIndex = new LinearIndex<Distance>(dataset, (const LinearIndexParams&)*bestParams, distance);
break;
case KDTREE:
bestIndex = new KDTreeIndex<ELEM_TYPE>(dataset, (const KDTreeIndexParams&)*bestParams);
bestIndex = new KDTreeIndex<Distance>(dataset, (const KDTreeIndexParams&)*bestParams, distance);
break;
case KMEANS:
bestIndex = new KMeansIndex<ELEM_TYPE>(dataset, (const KMeansIndexParams&)*bestParams);
bestIndex = new KMeansIndex<Distance>(dataset, (const KMeansIndexParams&)*bestParams, distance);
break;
default:
throw FLANNException("Unknown algorithm choosen by the autotuning, most likely a bug.");
throw FLANNException("Unknown algorithm chosen by the autotuning, most likely a bug.");
}
bestIndex->buildIndex();
speedup = estimateSearchParams(bestSearchParams);
......@@ -175,7 +208,7 @@ public:
int index_type;
load_value(stream,index_type);
IndexParams* params = ParamsFactory::instance().create((flann_algorithm_t)index_type);
bestIndex = create_index_by_type(dataset, *params);
bestIndex = index_by_type<Distance>(dataset, *params, distance);
bestIndex->loadIndex(stream);
load_value(stream, bestSearchParams);
}
......@@ -183,7 +216,7 @@ public:
/**
Method that searches for nearest-neighbors
*/
virtual void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
virtual void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
{
if (searchParams.checks==-2) {
bestIndex->findNeighbors(result, vec, bestSearchParams);
......@@ -253,7 +286,7 @@ private:
const int nn = 1;
logger.info("KMeansTree using params: max_iterations=%d, branching=%d\n", kmeans_params.iterations, kmeans_params.branching);
KMeansIndex<ELEM_TYPE> kmeans(sampledDataset, kmeans_params);
KMeansIndex<Distance> kmeans(sampledDataset, kmeans_params);
// measure index build time
t.start();
kmeans.buildIndex();
......@@ -279,7 +312,7 @@ private:
const int nn = 1;
logger.info("KDTree using params: trees=%d\n",kdtree_params.trees);
KDTreeIndex<ELEM_TYPE> kdtree(sampledDataset, kdtree_params);
KDTreeIndex<Distance> kdtree(sampledDataset, kdtree_params);
t.start();
kdtree.buildIndex();
......@@ -523,7 +556,7 @@ private:
gt_matches = Matrix<int>(new int[testDataset.rows],testDataset.rows, 1);
StartStopTimer t;
t.start();
compute_ground_truth(sampledDataset, testDataset, gt_matches, 0);
compute_ground_truth<Distance>(sampledDataset, testDataset, gt_matches, 0, distance);
t.stop();
float bestCost = t.value;
IndexParams* bestParams = new LinearIndexParams();
......@@ -570,7 +603,7 @@ private:
int samples = min(dataset.rows/10, SAMPLE_COUNT);
if (samples>0) {
Matrix<ELEM_TYPE> testDataset = random_sample(dataset,samples);
Matrix<ElementType> testDataset = random_sample(dataset,samples);
logger.info("Computing ground truth\n");
......@@ -578,7 +611,7 @@ private:
Matrix<int> gt_matches(new int[testDataset.rows],testDataset.rows,1);
StartStopTimer t;
t.start();
compute_ground_truth(dataset, testDataset, gt_matches,1);
compute_ground_truth<Distance>(dataset, testDataset, gt_matches, 1, distance);
t.stop();
float linear = t.value;
......@@ -589,7 +622,7 @@ private:
float cb_index;
if (bestIndex->getType() == KMEANS) {
logger.info("KMeans algorithm, estimating cluster border factor\n");
KMeansIndex<ELEM_TYPE>* kmeans = (KMeansIndex<ELEM_TYPE>*)bestIndex;
KMeansIndex<Distance>* kmeans = (KMeansIndex<Distance>*)bestIndex;
float bestSearchTime = -1;
float best_cb_index = -1;
int best_checks = -1;
......
......@@ -33,6 +33,8 @@
#include "flann/general.h"
#include "flann/algorithms/nn_index.h"
#include "flann/algorithms/kdtree_index.h"
#include "flann/algorithms/kmeans_index.h"
namespace flann
{
......@@ -89,27 +91,32 @@ struct CompositeIndexParams : public IndexParams {
template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
class CompositeIndex : public NNIndex<ELEM_TYPE>
template <typename Distance>
class CompositeIndex : public NNIndex<Distance>
{
KMeansIndex<ELEM_TYPE, DIST_TYPE>* kmeans;
KDTreeIndex<ELEM_TYPE, DIST_TYPE>* kdtree;
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
const Matrix<ELEM_TYPE> dataset;
KMeansIndex<Distance>* kmeans;
KDTreeIndex<Distance>* kdtree;
const Matrix<ElementType> dataset;
const IndexParams& index_params;
Distance distance;
public:
CompositeIndex(const Matrix<ELEM_TYPE>& inputData, const CompositeIndexParams& params = CompositeIndexParams() ) :
dataset(inputData), index_params(params)
CompositeIndex(const Matrix<ElementType>& inputData, const CompositeIndexParams& params = CompositeIndexParams(),
Distance d = Distance()) :
dataset(inputData), index_params(params), distance(d)
{
KDTreeIndexParams kdtree_params(params.trees);
KMeansIndexParams kmeans_params(params.branching, params.iterations, params.centers_init, params.cb_index);
kdtree = new KDTreeIndex<ELEM_TYPE, DIST_TYPE>(inputData,kdtree_params);
kmeans = new KMeansIndex<ELEM_TYPE, DIST_TYPE>(inputData,kmeans_params);
kdtree = new KDTreeIndex<Distance>(inputData,kdtree_params, d);
kmeans = new KMeansIndex<Distance>(inputData,kmeans_params, d);
}
......@@ -164,7 +171,7 @@ public:
kdtree->loadIndex(stream);
}
void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
{
kmeans->findNeighbors(result,vec,searchParams);
kdtree->findNeighbors(result,vec,searchParams);
......
/***********************************************************************
* Software License Agreement (BSD License)
*
* Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
* Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
*
* THE BSD LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#include "flann/algorithms/dist.h"
#include <cstdio>
namespace flann
{
/** Global variable indicating the distance metric
* to be used.
*/
flann_distance_t flann_distance_type = EUCLIDEAN;
/**
* Zero iterator that emulates a zero feature.
*/
ZeroIterator<float> zero;
/**
* Order of Minkowski distance to use.
*/
int flann_minkowski_order;
/* Accumulate the squared Euclidean distance between the byte sequence
 * [first1, last1) and the sequence starting at first2, on top of the
 * running total 'acc'.  The main loop is unrolled four components at a
 * time; a scalar loop mops up the 0-3 leftover components.
 */
double euclidean_dist(const unsigned char* first1, const unsigned char* last1, unsigned char* first2, double acc)
{
    double sum = acc;
    /* last position from which a full group of 4 can still be read */
    const unsigned char* unrolled_end = last1 - 3;

    /* four components per iteration */
    while (first1 < unrolled_end) {
        double d0 = first1[0] - first2[0];
        double d1 = first1[1] - first2[1];
        double d2 = first1[2] - first2[2];
        double d3 = first1[3] - first2[3];
        sum += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
        first1 += 4;
        first2 += 4;
    }

    /* remaining tail components */
    while (first1 < last1) {
        double d = *first1++ - *first2++;
        sum += d * d;
    }

    return sum;
}
}
此差异已折叠。
......@@ -52,9 +52,11 @@ namespace flann
{
struct KDTreeIndexParams : public IndexParams {
KDTreeIndexParams(int trees_ = 4) : IndexParams(KDTREE), trees(trees_) {};
KDTreeIndexParams(int trees_ = 4, int leaf_max_size_ = 4) :
IndexParams(KDTREE), trees(trees_), leaf_max_size(leaf_max_size_) {};
int trees; // number of randomized trees to use (for kdtree)
int leaf_max_size;
flann_algorithm_t getIndexType() const { return algorithm; }
......@@ -85,9 +87,11 @@ struct KDTreeIndexParams : public IndexParams {
* Contains the k-d trees and other information for indexing a set of points
* for nearest-neighbor matching.
*/
template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
class KDTreeIndex : public NNIndex<ELEM_TYPE>
template <typename Distance>
class KDTreeIndex : public NNIndex<Distance>
{
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
enum {
/**
......@@ -115,13 +119,15 @@ class KDTreeIndex : public NNIndex<ELEM_TYPE>
/**
* Array of indices to vectors in the dataset.
*/
int* vind;
int** vind;
int leaf_max_size_;
/**
* The dataset used by this index
*/
const Matrix<ELEM_TYPE> dataset;
const Matrix<ElementType> dataset;
const IndexParams& index_params;
......@@ -129,41 +135,36 @@ class KDTreeIndex : public NNIndex<ELEM_TYPE>
size_t veclen_;
DIST_TYPE* mean;
DIST_TYPE* var;
DistanceType* mean;
DistanceType* var;
/*--------------------- Internal Data Structures --------------------------*/
/**
* A node of the binary k-d tree.
*
* This is All nodes that have vec[divfeat] < divval are placed in the
* child1 subtree, else child2., A leaf node is indicated if both children are NULL.
*/
struct TreeSt {
/**
* Index of the vector feature used for subdivision.
* If this is a leaf node (both children are NULL) then
* this holds vector index for this leaf.
*/
int divfeat;
struct Node {
int *ind;
int count;
/**
* Dimension used for subdivision.
*/
int divfeat;
/**
* The value used for subdivision.
* The values used for subdivision.
*/
DIST_TYPE divval;
DistanceType divval;
/**
* The child nodes.
*/
TreeSt *child1, *child2;
};
typedef TreeSt* Tree;
Node *child1, *child2;
};
typedef Node* NodePtr;
/**
* Array of k-d trees used to find neighbours.
*/
Tree* trees;
typedef BranchStruct<Tree> BranchSt;
NodePtr* trees;
typedef BranchStruct<NodePtr> BranchSt;
typedef BranchSt* Branch;
/**
......@@ -175,7 +176,7 @@ class KDTreeIndex : public NNIndex<ELEM_TYPE>
*/
PooledAllocator pool;
Distance distance;
public:
......@@ -191,33 +192,25 @@ public:
* inputData = dataset with the input features
* params = parameters passed to the kdtree algorithm
*/
KDTreeIndex(const Matrix<ELEM_TYPE>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) :
dataset(inputData), index_params(params)
KDTreeIndex(const Matrix<ElementType>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams(),
Distance d = Distance() ) :
dataset(inputData), index_params(params), distance(d)
{
size_ = dataset.rows;
veclen_ = dataset.cols;
numTrees = params.trees;
trees = new Tree[numTrees];
// get the parameters
// if (params.find("trees") != params.end()) {
// numTrees = (int)params["trees"];
// trees = new Tree[numTrees];
// }
// else {
// numTrees = -1;
// trees = NULL;
// }
leaf_max_size_ = params.leaf_max_size;
trees = new NodePtr[numTrees];
// Create a permutable array of indices to the input vectors.
vind = new int[size_];
for (size_t i = 0; i < size_; i++) {
vind[i] = i;
vind = new int*[numTrees];
for (int i = 0; i < numTrees; i++) {
vind[i] = NULL;
}
mean = new DIST_TYPE[veclen_];
var = new DIST_TYPE[veclen_];
mean = new DistanceType[veclen_];
var = new DistanceType[veclen_];
}
/**
......@@ -225,6 +218,9 @@ public:
*/
~KDTreeIndex()
{
for (int i = 0; i < numTrees; i++) {
if (vind[i]!=NULL) delete[] vind[i];
}
delete[] vind;
if (trees!=NULL) {
delete[] trees;
......@@ -234,6 +230,15 @@ public:
}
template <typename Vector>
void randomizeVector(Vector& vec, int vec_size)
{
for (int j = vec_size; j > 0; --j) {
int rnd = rand_int(j);
swap(vec[j-1], vec[rnd]);
}
}
/**
* Builds the index
*/
......@@ -242,16 +247,15 @@ public:
/* Construct the randomized trees. */
for (int i = 0; i < numTrees; i++) {
/* Randomize the order of vectors to allow for unbiased sampling. */
for (int j = size_; j > 0; --j) {
int rnd = rand_int(j);
swap(vind[j-1], vind[rnd]);
vind[i] = new int[size_];
for (size_t j = 0; j < size_; j++) {
vind[i][j] = j;
}
trees[i] = divideTree(0, size_ - 1);
randomizeVector(vind[i], size_);
trees[i] = divideTree(vind[i], size_ );
}
}
void saveIndex(FILE* stream)
{
save_value(stream, numTrees);
......@@ -265,17 +269,15 @@ public:
void loadIndex(FILE* stream)
{
load_value(stream, numTrees);
if (trees!=NULL) {
delete[] trees;
}
trees = new Tree[numTrees];
trees = new NodePtr[numTrees];
for (int i=0;i<numTrees;++i) {
load_tree(stream,trees[i]);
}
}
/**
* Returns size of index.
*/
......@@ -292,7 +294,6 @@ public:
return veclen_;
}
/**
* Computes the index memory usage
* Returns: memory used by the index
......@@ -302,7 +303,6 @@ public:
return pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int); // pool memory and vind array memory
}
/**
* Find set of nearest neighbors to vec. Their indices are stored inside
* the result object.
......@@ -312,14 +312,15 @@ public:
* vec = the vector for which to search the nearest neighbors
* maxCheck = the maximum number of restarts (in a best-bin-first manner)
*/
void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
{
int maxChecks = searchParams.checks;
float epsError = 1+searchParams.eps;
if (maxChecks<0) {
getExactNeighbors(result, vec);
getExactNeighbors(result, vec, epsError);
} else {
getNeighbors(result, vec, maxChecks);
getNeighbors(result, vec, maxChecks, epsError);
}
}
......@@ -331,7 +332,7 @@ public:
private:
void save_tree(FILE* stream, Tree tree)
void save_tree(FILE* stream, NodePtr tree)
{
save_value(stream, *tree);
if (tree->child1!=NULL) {
......@@ -343,9 +344,9 @@ private:
}
void load_tree(FILE* stream, Tree& tree)
void load_tree(FILE* stream, NodePtr& tree)
{
tree = pool.allocate<TreeSt>();
tree = pool.allocate<Node>();
load_value(stream, *tree);
if (tree->child1!=NULL) {
load_tree(stream, tree->child1);
......@@ -365,18 +366,26 @@ private:
* first = index of the first vector
* last = index of the last vector
*/
Tree divideTree(int first, int last)
NodePtr divideTree(int* ind, int count)
{
Tree node = pool.allocate<TreeSt>(); // allocate memory
NodePtr node = pool.allocate<Node>(); // allocate memory
/* If only one exemplar remains, then make this a leaf node. */
if (first == last) {
/* If too few exemplars remain, then make this a leaf node. */
if ( count <= leaf_max_size_) {
node->child1 = node->child2 = NULL; /* Mark as leaf node. */
node->divfeat = vind[first]; /* Store index of this vec. */
node->ind = ind; /* Store index of this vec. */
node->count = count; /* and length */
}
else {
chooseDivision(node, first, last);
subdivide(node, first, last);
int idx;
int cutfeat;
DistanceType cutval;
meanSplit(ind, count, idx, cutfeat, cutval);
node->divfeat = cutfeat;
node->divval = cutval;
node->child1 = divideTree(ind, idx);
node->child2 = divideTree(ind+idx, count-idx);
}
return node;
......@@ -388,37 +397,46 @@ private:
* Make a random choice among those with the highest variance, and use
* its variance as the threshold value.
*/
void chooseDivision(Tree node, int first, int last)
void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
{
memset(mean,0,veclen_*sizeof(DIST_TYPE));
memset(var,0,veclen_*sizeof(DIST_TYPE));
memset(mean,0,veclen_*sizeof(DistanceType));
memset(var,0,veclen_*sizeof(DistanceType));
/* Compute mean values. Only the first SAMPLE_MEAN values need to be
sampled to get a good estimate.
*/
int end = min(first + SAMPLE_MEAN, last);
for (int j = first; j <= end; ++j) {
ELEM_TYPE* v = dataset[vind[j]];
int cnt = min((int)SAMPLE_MEAN+1, count);
for (int j = 0; j < cnt; ++j) {
ElementType* v = dataset[ind[j]];
for (size_t k=0; k<veclen_; ++k) {
mean[k] += v[k];
}
}
for (size_t k=0; k<veclen_; ++k) {
mean[k] /= (end - first + 1);
mean[k] /= cnt;
}
/* Compute variances (no need to divide by count). */
for (int j = first; j <= end; ++j) {
ELEM_TYPE* v = dataset[vind[j]];
for (int j = 0; j < cnt; ++j) {
ElementType* v = dataset[ind[j]];
for (size_t k=0; k<veclen_; ++k) {
DIST_TYPE dist = v[k] - mean[k];
DistanceType dist = v[k] - mean[k];
var[k] += dist * dist;
}
}
/* Select one of the highest variance indices at random. */
node->divfeat = selectDivision(var);
node->divval = mean[node->divfeat];
cutfeat = selectDivision(var);
cutval = mean[cutfeat];
int lim1, lim2;
planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
if (lim1>count/2) index = lim1;
else if (lim2<count/2) index = lim2;
else index = count/2;
// in the unlikely case all values are equal
if (lim1==cnt || lim2==0) index = count/2;
}
......@@ -426,7 +444,7 @@ private:
* Select the top RAND_DIM largest values from v and return the index of
* one of these selected at random.
*/
int selectDivision(DIST_TYPE* v)
int selectDivision(DistanceType* v)
{
int num = 0;
int topind[RAND_DIM];
......@@ -456,44 +474,44 @@ private:
/**
* Subdivide the list of exemplars using the feature and division
* value given in this node. Call divideTree recursively on each list.
* Subdivide the list of points by a plane perpendicular on axe corresponding
* to the 'cutfeat' dimension at 'cutval' position.
*
* On return:
* dataset[ind[0..lim1-1]][cutfeat]<cutval
* dataset[ind[lim1..lim2-1]][cutfeat]==cutval
* dataset[ind[lim2..count]][cutfeat]>cutval
*/
void subdivide(Tree node, int first, int last)
void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
{
/* Move vector indices for left subtree to front of list. */
int i = first;
int j = last;
while (i <= j) {
int ind = vind[i];
ELEM_TYPE val = dataset[ind][node->divfeat];
if (val < node->divval) {
++i;
} else {
/* Move to end of list by swapping vind i and j. */
swap(vind[i], vind[j]);
--j;
}
int left = 0;
int right = count-1;
for (;;) {
while (left<=right && dataset[ind[left]][cutfeat]<cutval) ++left;
while (left<=right && dataset[ind[right]][cutfeat]>=cutval) --right;
if (left>right) break;
swap(ind[left], ind[right]); ++left; --right;
}
/* If either list is empty, it means we have hit the unlikely case
in which all remaining features are identical. Split in the middle
to maintain a balanced tree.
/* If either list is empty, it means that all remaining features
* are identical. Split in the middle to maintain a balanced tree.
*/
if ( (i == first) || (i == last+1)) {
i = (first+last+1)/2;
lim1 = left;
right = count-1;
for (;;) {
while (left<=right && dataset[ind[left]][cutfeat]<=cutval) ++left;
while (left<=right && dataset[ind[right]][cutfeat]>cutval) --right;
if (left>right) break;
swap(ind[left], ind[right]); ++left; --right;
}
node->child1 = divideTree(first, i - 1);
node->child2 = divideTree(i, last);
lim2 = left;
}
/**
* Performs an exact nearest neighbor search. The exact search performs a full
* traversal of the tree.
*/
void getExactNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec)
void getExactNeighbors(ResultSet& result, const ElementType* vec, float epsError)
{
// checkID -= 1; /* Set a different unique ID for each search. */
......@@ -501,7 +519,7 @@ private:
fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
}
if (numTrees>0) {
searchLevelExact(result, vec, trees[0], 0.0);
searchLevelExact(result, vec, trees[0], 0.0, epsError);
}
assert(result.full());
}
......@@ -511,7 +529,7 @@ private:
* because the tree traversal is abandoned after a given number of descends in
* the tree.
*/
void getNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, int maxCheck)
void getNeighbors(ResultSet& result, const ElementType* vec, int maxCheck, float epsError)
{
int i;
BranchSt branch;
......@@ -522,12 +540,12 @@ private:
/* Search once through each tree down to root. */
for (i = 0; i < numTrees; ++i) {
searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, heap, checked);
searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck, epsError, heap, checked);
}
/* Keep searching other branches from heap until finished. */
while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, heap, checked);
searchLevel(result, vec, branch.node, branch.mindistsq, checkCount, maxCheck, epsError, heap, checked);
}
delete heap;
......@@ -541,36 +559,42 @@ private:
* higher levels, all exemplars below this level must have a distance of
* at least "mindistsq".
*/
void searchLevel(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq, int& checkCount, int maxCheck,
Heap<BranchSt>* heap, vector<bool>& checked)
void searchLevel(ResultSet& result_set, const ElementType* vec, NodePtr node, float mindistsq, int& checkCount, int maxCheck,
float epsError, Heap<BranchSt>* heap, vector<bool>& checked)
{
if (result.worstDist()<mindistsq) {
if (result_set.worstDist()<mindistsq) {
// printf("Ignoring branch, too far\n");
return;
}
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
/* Do not check same node more than once when searching multiple trees.
/* Do not check same node more than once when searching multiple trees.
Once a vector is checked, we set its location in vind to the
current checkID.
*/
if (checked[node->divfeat] == true || checkCount>=maxCheck) {
if (result.full()) return;
}
checkCount++;
checked[node->divfeat] = true;
float worst_dist = result_set.worstDist();
for (int i=0;i<node->count;++i) {
int index = node->ind[i];
if (checked[index] == true || checkCount>=maxCheck) {
if (result_set.full()) continue;
}
checked[index] = true;
checkCount++;
result.addPoint(dataset[node->divfeat],node->divfeat);
DistanceType dist = distance(dataset[index], vec, veclen_);
if (dist<worst_dist) {
result_set.addPoint(dist,index);
}
}
return;
}
/* Which child branch should be taken first? */
ELEM_TYPE val = vec[node->divfeat];
DIST_TYPE diff = val - node->divval;
Tree bestChild = (diff < 0) ? node->child1 : node->child2;
Tree otherChild = (diff < 0) ? node->child2 : node->child1;
ElementType val = vec[node->divfeat];
DistanceType diff = val - node->divval;
NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
/* Create a branch record for the branch not taken. Add distance
of this feature boundary (we don't attempt to correct for any
......@@ -580,54 +604,59 @@ private:
adding exceeds their value.
*/
DIST_TYPE new_distsq = flann_dist(&val, &val+1, &node->divval, mindistsq);
DistanceType new_distsq = mindistsq + distance.accum_dist(val, node->divval);
// if (2 * checkCount < maxCheck || !result.full()) {
if (new_distsq < result.worstDist() || !result.full()) {
heap->insert( BranchSt::make_branch(otherChild, new_distsq) );
if (new_distsq*epsError < result_set.worstDist() || !result_set.full()) {
heap->insert( BranchSt(otherChild, new_distsq) );
}
/* Call recursively to search next level down. */
searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck, heap, checked);
searchLevel(result_set, vec, bestChild, mindistsq, checkCount, maxCheck, epsError, heap, checked);
}
/**
* Performs an exact search in the tree starting from a node.
*/
void searchLevelExact(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, Tree node, float mindistsq)
void searchLevelExact(ResultSet& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
{
if (mindistsq>result.worstDist()) {
return;
}
/* If this is a leaf node, then do check and return. */
if (node->child1 == NULL && node->child2 == NULL) {
/* Do not check same node more than once when searching multiple trees.
Once a vector is checked, we set its location in vind to the
current checkID.
*/
// if (vind[node->divfeat] == checkID)
// return;
// vind[node->divfeat] = checkID;
result.addPoint(dataset[node->divfeat],node->divfeat);
float worst_dist = result_set.worstDist();
for (int i=0;i<node->count;++i) {
int index = node->ind[i];
DistanceType dist = distance(dataset[index], vec, veclen_);
if (dist<worst_dist) {
result_set.addPoint(dist,index);
}
}
return;
}
/* Which child branch should be taken first? */
ELEM_TYPE val = vec[node->divfeat];
DIST_TYPE diff = val - node->divval;
Tree bestChild = (diff < 0) ? node->child1 : node->child2;
Tree otherChild = (diff < 0) ? node->child2 : node->child1;
ElementType val = vec[node->divfeat];
DistanceType diff = val - node->divval;
NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
/* Create a branch record for the branch not taken. Add distance
of this feature boundary (we don't attempt to correct for any
use of this feature in a parent node, which is unlikely to
happen and would have only a small effect). Don't bother
adding more branches to heap after halfway point, as cost of
adding exceeds their value.
*/
DistanceType new_distsq = mindistsq + distance.accum_dist(val, node->divval);
/* Call recursively to search next level down. */
searchLevelExact(result, vec, bestChild, mindistsq);
DIST_TYPE new_distsq = flann_dist(&val, &val+1, &node->divval, mindistsq);
searchLevelExact(result, vec, otherChild, new_distsq);
searchLevelExact(result_set, vec, bestChild, mindistsq, epsError);
if (new_distsq*epsError<=result_set.worstDist()) {
searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
}
}
}; // class KDTree
}; // class KDTreeForest
}
......
/***********************************************************************
* Software License Agreement (BSD License)
*
* Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
* Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
*
* THE BSD LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#ifndef KDTREESIMPLE_H
#define KDTREESIMPLE_H
#include <algorithm>
#include <map>
#include <cassert>
#include <cstring>
#include "flann/general.h"
#include "flann/algorithms/nn_index.h"
#include "flann/util/matrix.h"
#include "flann/util/result_set.h"
#include "flann/util/heap.h"
#include "flann/util/allocator.h"
#include "flann/util/random.h"
#include "flann/util/saving.h"
using namespace std;
namespace flann
{
/**
 * Index parameters for KDTreeSimpleIndex (a single kd-tree, intended for
 * low-dimensionality data such as 3D point clouds).
 */
struct KDTreeSimpleIndexParams : public IndexParams {
    /**
     * leaf_max_size_ : maximum number of points stored in a single leaf
     * node of the tree (default 1).
     */
    KDTreeSimpleIndexParams(int leaf_max_size_ = 1) :
        IndexParams(KDTREE_SIMPLE), leaf_max_size(leaf_max_size_) {};

    int leaf_max_size;  // maximum number of points per leaf node

    /** Algorithm type identifying this parameter set (KDTREE_SIMPLE). */
    flann_algorithm_t getIndexType() const { return algorithm; }

    /**
     * Populates this object from the C-API parameter struct.
     * NOTE(review): FLANNParameters has a leaf_max_size field, but it is not
     * read here, so C-API callers cannot configure it — confirm intended.
     */
    void fromParameters(const FLANNParameters& p)
    {
        assert(p.algorithm==algorithm);
        // trees = p.trees;
    }

    /**
     * Writes this object's settings into the C-API parameter struct.
     * NOTE(review): leaf_max_size is not written back either — see above.
     */
    void toParameters(FLANNParameters& p) const
    {
        p.algorithm = algorithm;
        // p.trees = trees;
    }

    /** Logs the parameter values via the library logger. */
    void print() const
    {
        logger.info("Index type: %d\n",(int)algorithm);
        // logger.info("Trees: %d\n", trees);
    }
};
/**
* Randomized kd-tree index
*
* Contains the k-d trees and other information for indexing a set of points
* for nearest-neighbor matching.
*/
/**
 * Single kd-tree index.
 *
 * Contains the kd-tree and other information for indexing a set of points
 * for nearest-neighbor matching. Works best for low-dimensionality points,
 * such as 3D point clouds.
 */
template <typename Distance>
class KDTreeSimpleIndex : public NNIndex<Distance>
{
    typedef typename Distance::ElementType ElementType;
    typedef typename Distance::ResultType DistanceType;

    /**
     * Permutable array of indices into the dataset rows. Leaf nodes point
     * into sub-ranges of this array, so it must outlive the tree nodes.
     */
    int* vind;

    /** Maximum number of points a leaf node may hold. */
    int leaf_max_size_;

    /** The dataset used by this index (the matrix memory is not owned). */
    const Matrix<ElementType> dataset;

    const IndexParams& index_params;

    size_t size_;    // number of points in the dataset
    size_t veclen_;  // dimensionality of each point

    /*--------------------- Internal Data Structures --------------------------*/
    struct Node {
        /** Indices of the points stored in this node (valid for leaves). */
        int *ind;
        /** Number of entries in 'ind' (valid for leaves). */
        int count;
        /** Dimension used for subdivision. */
        int divfeat;
        /**
         * The values used for subdivision: divlow is the maximum of the left
         * subtree along 'divfeat', divhigh the minimum of the right subtree.
         */
        DistanceType divlow, divhigh;
        /** Borders of this cell along the splitting dimension. */
        DistanceType lowval, highval;
        /** The child nodes (both NULL marks a leaf). */
        Node *child1, *child2;
        // NOTE: nodes are serialized byte-wise by save_value/load_value, so
        // the member layout above is part of the saved-index format.
    };
    typedef Node* NodePtr;

    /** Axis-aligned bounding box, owning its per-dimension bound arrays. */
    struct BoundingBox {
        ElementType* low;   // per-dimension minimum
        ElementType* high;  // per-dimension maximum
        size_t size;        // number of dimensions

        BoundingBox() {
            low = NULL;
            high = NULL;
        }

        ~BoundingBox() {
            if (low!=NULL) delete[] low;
            if (high!=NULL) delete[] high;
        }
        // NOTE(review): owns raw arrays but defines no copy constructor or
        // assignment; copying would double-free. Only used as a by-value
        // member of the index, which is never copied here — confirm no
        // copies are made elsewhere.

        /** Computes the per-dimension min/max over all rows of 'data'. */
        void computeFromData(const Matrix<ElementType>& data)
        {
            assert(data.rows>0);
            size = data.cols;
            low = new ElementType[size];
            high = new ElementType[size];
            for (size_t i=0;i<size;++i) {
                low[i] = data[0][i];
                high[i] = data[0][i];
            }
            for (size_t k=1;k<data.rows;++k) {
                for (size_t i=0;i<size;++i) {
                    if (data[k][i]<low[i]) low[i] = data[k][i];
                    if (data[k][i]>high[i]) high[i] = data[k][i];
                }
            }
        }
    };

    /** Root node of the (single) kd-tree. */
    NodePtr root_node;

    typedef BranchStruct<NodePtr> BranchSt;
    typedef BranchSt* Branch;

    /** Bounding box of the whole dataset; shrunk/restored during build. */
    BoundingBox bbox;

    /**
     * Pooled memory allocator.
     *
     * Using a pooled memory allocator is more efficient than allocating
     * memory directly when there is a large number of small allocations.
     */
    PooledAllocator pool;

public:

    Distance distance;
    int count_leaf;  // number of leaf points visited by searchLevel calls

    flann_algorithm_t getType() const
    {
        // BUGFIX: previously returned KDTREE, which disagreed with
        // KDTreeSimpleIndexParams (KDTREE_SIMPLE) and would mislabel
        // the index type in saved-index headers.
        return KDTREE_SIMPLE;
    }

    /**
     * KDTreeSimpleIndex constructor
     *
     * Params:
     *   inputData = dataset with the input features
     *   params = parameters passed to the kdtree algorithm
     *   d = distance functor
     */
    KDTreeSimpleIndex(const Matrix<ElementType>& inputData, const KDTreeSimpleIndexParams& params = KDTreeSimpleIndexParams(),
            Distance d = Distance() ) :
        dataset(inputData), index_params(params), distance(d)
    {
        size_ = dataset.rows;
        veclen_ = dataset.cols;
        leaf_max_size_ = params.leaf_max_size;

        // Create a permutable array of indices to the input vectors.
        vind = new int[size_];
        for (size_t i = 0; i < size_; i++) {
            vind[i] = i;
        }
        // Random shuffle so that identical coordinates split evenly.
        randomizeVector(vind, size_);

        bbox.computeFromData(dataset);

        count_leaf = 0;
    }

    /**
     * Standard destructor. Tree nodes are owned by the pooled allocator.
     */
    ~KDTreeSimpleIndex()
    {
        delete[] vind;
    }

    /** Fisher-Yates shuffle of the first 'vec_size' entries of 'vec'. */
    template <typename Vector>
    void randomizeVector(Vector& vec, int vec_size)
    {
        for (int j = vec_size; j > 0; --j) {
            int rnd = rand_int(j);
            swap(vec[j-1], vec[rnd]);
        }
    }

    /**
     * Builds the index (constructs the tree over all points).
     */
    void buildIndex()
    {
        root_node = divideTree(vind, size_ );   // construct the tree
    }

    /** Saves the tree structure to an already-open stream. */
    void saveIndex(FILE* stream)
    {
        save_tree(stream, root_node);
    }

    /** Loads the tree structure from an already-open stream. */
    void loadIndex(FILE* stream)
    {
        load_tree(stream, root_node);
    }

    /**
     * Returns size of index (number of points).
     */
    size_t size() const
    {
        return size_;
    }

    /**
     * Returns the length of an index feature.
     */
    size_t veclen() const
    {
        return veclen_;
    }

    /**
     * Computes the index memory usage.
     * Returns: memory used by the index
     */
    int usedMemory() const
    {
        return pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int);  // pool memory and vind array memory
    }

    /**
     * Find set of nearest neighbors to vec. Their indices are stored inside
     * the result object.
     *
     * Params:
     *   result = the result object in which the indices of the nearest-neighbors are stored
     *   vec = the vector for which to search the nearest neighbors
     *   searchParams = search parameters; only 'eps' is used here.
     *     NOTE(review): searchParams.checks is ignored — this index always
     *     performs a full (eps-approximate) branch-and-bound traversal.
     */
    void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams)
    {
        float epsError = 1+searchParams.eps;

        // Distance from the query to the dataset bounding box; a lower
        // bound for every point in the tree.
        float distsq = computeInitialDistance(vec);
        searchLevel(result, vec, root_node, distsq, epsError);
    }

    const IndexParams* getParameters() const
    {
        return &index_params;
    }

private:

    /** Recursively serializes the subtree rooted at 'tree' (pre-order). */
    void save_tree(FILE* stream, NodePtr tree)
    {
        save_value(stream, *tree);
        if (tree->child1!=NULL) {
            save_tree(stream, tree->child1);
        }
        if (tree->child2!=NULL) {
            save_tree(stream, tree->child2);
        }
    }

    /**
     * Recursively deserializes a subtree (pre-order). The saved child
     * pointers are only used as NULL / non-NULL markers; they are
     * overwritten with freshly allocated nodes here.
     */
    void load_tree(FILE* stream, NodePtr& tree)
    {
        tree = pool.allocate<Node>();
        load_value(stream, *tree);
        if (tree->child1!=NULL) {
            load_tree(stream, tree->child1);
        }
        if (tree->child2!=NULL) {
            load_tree(stream, tree->child2);
        }
    }

    /**
     * Creates a tree node that subdivides the points ind[0..count-1].
     * The routine is called recursively on each sublist.
     *
     * Params:
     *   ind = pointer into the vind array; first index of this sublist
     *   count = number of indices in the sublist
     * Returns: the newly created (sub)tree root.
     */
    NodePtr divideTree(int* ind, int count)
    {
        NodePtr node = pool.allocate<Node>();   // allocate memory

        /* If too few exemplars remain, then make this a leaf node. */
        if ( count <= leaf_max_size_) {
            node->child1 = node->child2 = NULL;   /* Mark as leaf node. */
            node->ind = ind;                      /* Store indices of this leaf. */
            node->count = count;                  /* and their number */
        }
        else {
            int idx;
            int cutfeat;
            DistanceType cutval;
            ElementType min_val, max_val;
            middleSplit(ind, count, idx, cutfeat, cutval);

            node->divfeat = cutfeat;
            // Remember the current cell borders so they can be restored
            // after recursing (bbox is shrunk in place while building).
            node->lowval = bbox.low[cutfeat];
            node->highval = bbox.high[cutfeat];

            // Left subtree: tighten the upper bound to the left set's max.
            computeMinMax(ind, idx, cutfeat, min_val, max_val);
            bbox.high[cutfeat] = max_val;
            node->divlow = max_val;
            node->child1 = divideTree(ind, idx);
            bbox.high[cutfeat] = node->highval;

            // Right subtree: tighten the lower bound to the right set's min.
            computeMinMax(ind+idx, count-idx, cutfeat, min_val, max_val);
            bbox.low[cutfeat] = min_val;
            node->divhigh = min_val;
            node->child2 = divideTree(ind+idx, count-idx);
            bbox.low[cutfeat] = node->lowval;
        }

        return node;
    }

    /** Returns max-min of dimension 'dim' over the points ind[0..count-1]. */
    ElementType computeSpread(int* ind, int count, int dim)
    {
        // BUGFIX: renamed from computeSpead (typo); private helper, updated
        // at its single call site in middleSplit.
        ElementType min_elem, max_elem;
        computeMinMax(ind, count, dim, min_elem, max_elem);
        return max_elem-min_elem;
    }

    /** Computes min/max of dimension 'dim' over the points ind[0..count-1]. */
    void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
    {
        min_elem = dataset[ind[0]][dim];
        max_elem = dataset[ind[0]][dim];
        for (int i=1;i<count;++i) {
            ElementType val = dataset[ind[i]][dim];
            if (val<min_elem) min_elem = val;
            if (val>max_elem) max_elem = val;
        }
    }

    /**
     * Chooses a split dimension and value: among the dimensions whose
     * bounding-box span is (almost) maximal, picks the one with the largest
     * actual spread of the points, and splits at the middle of the cell
     * (clamped to the data range). Outputs:
     *   index = number of points that go to the left child
     *   cutfeat = chosen split dimension
     *   cutval = chosen split value
     */
    void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
    {
        // Tolerance for "span is (almost) maximal".
        const float EPS=0.00001;
        ElementType max_span = bbox.high[0]-bbox.low[0];
        for (size_t i=1;i<veclen_;++i) {
            ElementType span = bbox.high[i]-bbox.low[i];
            if (span>max_span) {
                max_span = span;
            }
        }
        ElementType max_spread = -1;
        cutfeat = 0;
        for (size_t i=0;i<veclen_;++i) {
            ElementType span = bbox.high[i]-bbox.low[i];
            if (span>(1-EPS)*max_span) {
                ElementType spread = computeSpread(ind, count, i);
                if (spread>max_spread) {
                    cutfeat = i;
                    max_spread = spread;
                }
            }
        }
        // Split in the middle of the cell, clamped to the data range so
        // that neither side is empty.
        DistanceType split_val = (bbox.low[cutfeat]+bbox.high[cutfeat])/2;
        ElementType min_elem, max_elem;
        computeMinMax(ind, count, cutfeat, min_elem, max_elem);

        if (split_val<min_elem) cutval = min_elem;
        else if (split_val>max_elem) cutval = max_elem;
        else cutval = split_val;

        int lim1, lim2;
        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);

        // Prefer a balanced split when the ==cutval run straddles the middle.
        if (lim1>count/2) index = lim1;
        else if (lim2<count/2) index = lim2;
        else index = count/2;
    }

    /**
     * Subdivide the list of points by a plane perpendicular on the axis
     * corresponding to the 'cutfeat' dimension at 'cutval' position.
     *
     * On return:
     *   dataset[ind[0..lim1-1]][cutfeat]  <  cutval
     *   dataset[ind[lim1..lim2-1]][cutfeat] == cutval
     *   dataset[ind[lim2..count-1]][cutfeat] > cutval
     */
    void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
    {
        /* Move vector indices for left subtree to front of list. */
        int left = 0;
        int right = count-1;
        for (;;) {
            while (left<=right && dataset[ind[left]][cutfeat]<cutval) ++left;
            while (left<=right && dataset[ind[right]][cutfeat]>=cutval) --right;
            if (left>right) break;
            swap(ind[left], ind[right]); ++left; --right;
        }
        /* If either list is empty, it means that all remaining features
         * are identical. The caller (middleSplit) then splits in the middle
         * to maintain a balanced tree.
         */
        lim1 = left;
        right = count-1;
        for (;;) {
            while (left<=right && dataset[ind[left]][cutfeat]<=cutval) ++left;
            while (left<=right && dataset[ind[right]][cutfeat]>cutval) --right;
            if (left>right) break;
            swap(ind[left], ind[right]); ++left; --right;
        }
        lim2 = left;
    }

    /**
     * Accumulated per-dimension distance from 'vec' to the dataset bounding
     * box (0 if the query lies inside the box) — a lower bound on the
     * distance to any indexed point.
     */
    float computeInitialDistance(const ElementType* vec)
    {
        float distsq = 0.0;

        for (size_t i=0;i<veclen();++i) {
            if (vec[i]<bbox.low[i]) distsq += distance.accum_dist(vec[i], bbox.low[i]);
            if (vec[i]>bbox.high[i]) distsq += distance.accum_dist(vec[i], bbox.high[i]);
        }

        return distsq;
    }

    /**
     * Performs a branch-and-bound search in the tree starting from a node.
     * 'mindistsq' is a lower bound on the distance from the query to any
     * point below 'node'; subtrees are pruned when
     * mindistsq*epsError exceeds the current worst result distance
     * (epsError==1 gives an exact search).
     */
    void searchLevel(ResultSet& result_set, const ElementType* vec, const NodePtr node, float mindistsq, const float epsError)
    {
        /* If this is a leaf node, then do check and return. */
        if (node->child1 == NULL && node->child2 == NULL) {
            count_leaf += node->count;
            float worst_dist = result_set.worstDist();
            for (int i=0;i<node->count;++i) {
                int index = node->ind[i];
                // Distance with early bail-out at worst_dist.
                float dist = distance(vec, dataset[index], veclen_, worst_dist);
                if (dist<worst_dist) {
                    result_set.addPoint(dist,index);
                }
            }
            return;
        }

        /* Which child branch should be taken first? */
        ElementType val = vec[node->divfeat];
        DistanceType diff1 = val - node->divlow;
        DistanceType diff2 = val - node->divhigh;

        NodePtr bestChild;
        NodePtr otherChild;
        float cut_dist = 0;
        if ((diff1+diff2)<0) {
            // Query is closer to the left cell; the extra cost of visiting
            // the right child is the distance to its near border (divhigh),
            // corrected when the query already lies outside this cell.
            bestChild = node->child1;
            otherChild = node->child2;

            cut_dist = distance.accum_dist(val, node->divhigh);
            if (val<node->lowval) {   // outside of cell, correct distance
                cut_dist -= distance.accum_dist(val, node->lowval);
            }
        }
        else {
            bestChild = node->child2;
            otherChild = node->child1;

            cut_dist = distance.accum_dist( val, node->divlow);
            if (val>node->highval) {   // outside of cell, correct distance
                cut_dist -= distance.accum_dist(val, node->highval);
            }
        }

        /* Call recursively to search next level down. */
        searchLevel(result_set, vec, bestChild, mindistsq, epsError);

        mindistsq = mindistsq + cut_dist;
        if (mindistsq*epsError<=result_set.worstDist()) {
            searchLevel(result_set, vec, otherChild, mindistsq, epsError);
        }
    }

};   // class KDTreeSimpleIndex
}
#endif //KDTREESIMPLE_H
......@@ -58,17 +58,22 @@ struct LinearIndexParams : public IndexParams {
}
};
template <typename ELEM_TYPE, typename DIST_TYPE = typename DistType<ELEM_TYPE>::type >
class LinearIndex : public NNIndex<ELEM_TYPE>
template <typename Distance>
class LinearIndex : public NNIndex<Distance>
{
const Matrix<ELEM_TYPE> dataset;
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
const Matrix<ElementType> dataset;
const LinearIndexParams& index_params;
Distance distance;
public:
LinearIndex(const Matrix<ELEM_TYPE>& inputData, const LinearIndexParams& params = LinearIndexParams() ) :
dataset(inputData), index_params(params)
LinearIndex(const Matrix<ElementType>& inputData, const LinearIndexParams& params = LinearIndexParams(),
Distance d = Distance()) :
dataset(inputData), index_params(params), distance(d)
{
}
......@@ -110,10 +115,11 @@ public:
/* nothing to do here for linear search */
}
void findNeighbors(ResultSet<ELEM_TYPE>& resultSet, const ELEM_TYPE* vec, const SearchParams& searchParams)
void findNeighbors(ResultSet& resultSet, const ElementType* vec, const SearchParams& searchParams)
{
for (size_t i=0;i<dataset.rows;++i) {
resultSet.addPoint(dataset[i],i);
DistanceType dist = distance(dataset[i],vec, dataset.cols);
resultSet.addPoint(dist,i);
}
}
......
......@@ -41,16 +41,16 @@ using namespace std;
namespace flann
{
template <typename ELEM_TYPE>
class ResultSet;
/**
* Nearest-neighbour index base class
*/
template <typename ELEM_TYPE>
template <typename Distance>
class NNIndex
{
typedef typename Distance::ElementType ElementType;
public:
virtual ~NNIndex() {};
......@@ -73,7 +73,7 @@ public:
/**
Method that searches for nearest-neighbors
*/
virtual void findNeighbors(ResultSet<ELEM_TYPE>& result, const ELEM_TYPE* vec, const SearchParams& searchParams) = 0;
virtual void findNeighbors(ResultSet& result, const ElementType* vec, const SearchParams& searchParams) = 0;
/**
Number of features in this index.
......
此差异已折叠。
......@@ -221,7 +221,7 @@ LIBSPEC int flann_find_nearest_neighbors_double(double* dataset,
double* testset,
int trows,
int* indices,
float* dists,
double* dists,
int nn,
struct FLANNParameters* flann_params);
......@@ -280,7 +280,7 @@ LIBSPEC int flann_find_nearest_neighbors_index_double(flann_index_t index_id,
double* testset,
int trows,
int* indices,
float* dists,
double* dists,
int nn,
struct FLANNParameters* flann_params);
......@@ -334,7 +334,7 @@ LIBSPEC int flann_radius_search_float(flann_index_t index_ptr, /* the index */
LIBSPEC int flann_radius_search_double(flann_index_t index_ptr, /* the index */
double* query, /* query point */
int* indices, /* array for storing the indices found (will be modified) */
float* dists, /* similar, but for storing distances */
double* dists, /* similar, but for storing distances */
int max_nn, /* size of arrays indices and dists */
float radius, /* search radius (squared radius for euclidian metric) */
struct FLANNParameters* flann_params);
......
......@@ -59,14 +59,6 @@ Params:
void log_verbosity(int level);
/**
* Sets the distance type to use throughout FLANN.
* If distance type specified is MINKOWSKI, the second argument
* specifies which order the minkowski distance should have.
*/
void set_distance_type(flann_distance_t distance_type, int order);
struct SavedIndexParams : public IndexParams {
SavedIndexParams(std::string filename_) : IndexParams(SAVED), filename(filename_) {}
......@@ -93,21 +85,24 @@ struct SavedIndexParams : public IndexParams {
}
};
template<typename T>
template<typename Distance>
class Index {
NNIndex<T>* nnIndex;
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
Distance distance;
NNIndex<Distance>* nnIndex;
bool built;
public:
Index(const Matrix<T>& features, const IndexParams& params);
Index(const Matrix<ElementType>& features, const IndexParams& params, Distance d = Distance() );
~Index();
void buildIndex();
void knnSearch(const Matrix<T>& queries, Matrix<int>& indices, Matrix<float>& dists, int knn, const SearchParams& params);
void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params);
int radiusSearch(const Matrix<T>& query, Matrix<int>& indices, Matrix<float>& dists, float radius, const SearchParams& params);
int radiusSearch(const Matrix<ElementType>& query, Matrix<int>& indices, Matrix<DistanceType>& dists, float radius, const SearchParams& params);
void save(std::string filename);
......@@ -115,21 +110,23 @@ public:
int size() const;
NNIndex<T>* getIndex() { return nnIndex; }
NNIndex<Distance>* getIndex() { return nnIndex; }
const IndexParams* getIndexParameters() { return nnIndex->getParameters(); }
};
template<typename T>
NNIndex<T>* load_saved_index(const Matrix<T>& dataset, const string& filename)
template<typename Distance>
NNIndex<Distance>* load_saved_index(const Matrix<typename Distance::ElementType>& dataset, const string& filename, Distance distance)
{
typedef typename Distance::ElementType ElementType;
FILE* fin = fopen(filename.c_str(), "rb");
if (fin==NULL) {
return NULL;
}
IndexHeader header = load_header(fin);
if (header.data_type!=get_flann_datatype<T>()) {
if (header.data_type!=get_flann_datatype<ElementType>()) {
throw FLANNException("Datatype of saved index is different than of the one to be created.");
}
if (size_t(header.rows)!=dataset.rows || size_t(header.cols)!=dataset.cols) {
......@@ -137,7 +134,7 @@ NNIndex<T>* load_saved_index(const Matrix<T>& dataset, const string& filename)
}
IndexParams* params = ParamsFactory::instance().create(header.index_type);
NNIndex<T>* nnIndex = create_index_by_type(dataset, *params);
NNIndex<Distance>* nnIndex = create_index_by_type<Distance>(dataset, *params, distance);
nnIndex->loadIndex(fin);
fclose(fin);
......@@ -145,29 +142,29 @@ NNIndex<T>* load_saved_index(const Matrix<T>& dataset, const string& filename)
}
template<typename T>
Index<T>::Index(const Matrix<T>& dataset, const IndexParams& params)
template<typename Distance>
Index<Distance>::Index(const Matrix<ElementType>& dataset, const IndexParams& params, Distance d ) : distance (d)
{
flann_algorithm_t index_type = params.getIndexType();
built = false;
if (index_type==SAVED) {
nnIndex = load_saved_index(dataset, ((const SavedIndexParams&)params).filename);
nnIndex = load_saved_index<Distance>(dataset, ((const SavedIndexParams&)params).filename, distance);
built = true;
}
else {
nnIndex = create_index_by_type(dataset, params);
nnIndex = create_index_by_type<Distance>(dataset, params, distance);
}
}
template<typename T>
Index<T>::~Index()
template<typename Distance>
Index<Distance>::~Index()
{
delete nnIndex;
}
template<typename T>
void Index<T>::buildIndex()
template<typename Distance>
void Index<Distance>::buildIndex()
{
if (!built) {
nnIndex->buildIndex();
......@@ -175,8 +172,8 @@ void Index<T>::buildIndex()
}
}
template<typename T>
void Index<T>::knnSearch(const Matrix<T>& queries, Matrix<int>& indices, Matrix<float>& dists, int knn, const SearchParams& searchParams)
template<typename Distance>
void Index<Distance>::knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& searchParams)
{
if (!built) {
throw FLANNException("You must build the index before searching.");
......@@ -187,11 +184,11 @@ void Index<T>::knnSearch(const Matrix<T>& queries, Matrix<int>& indices, Matrix<
assert(int(indices.cols)>=knn);
assert(int(dists.cols)>=knn);
KNNResultSet<T> resultSet(knn);
KNNResultSet resultSet(knn);
for (size_t i = 0; i < queries.rows; i++) {
T* target = queries[i];
resultSet.init(target, queries.cols);
ElementType* target = queries[i];
resultSet.init();
nnIndex->findNeighbors(resultSet, target, searchParams);
......@@ -202,8 +199,8 @@ void Index<T>::knnSearch(const Matrix<T>& queries, Matrix<int>& indices, Matrix<
}
}
template<typename T>
int Index<T>::radiusSearch(const Matrix<T>& query, Matrix<int>& indices, Matrix<float>& dists, float radius, const SearchParams& searchParams)
template<typename Distance>
int Index<Distance>::radiusSearch(const Matrix<ElementType>& query, Matrix<int>& indices, Matrix<DistanceType>& dists, float radius, const SearchParams& searchParams)
{
if (!built) {
throw FLANNException("You must build the index before searching.");
......@@ -214,9 +211,9 @@ int Index<T>::radiusSearch(const Matrix<T>& query, Matrix<int>& indices, Matrix<
}
assert(query.cols==nnIndex->veclen());
RadiusResultSet<T> resultSet(radius);
resultSet.init(query.data, query.cols);
nnIndex->findNeighbors(resultSet,query.data,searchParams);
RadiusResultSet resultSet(radius);
resultSet.init();
nnIndex->findNeighbors(resultSet, query[0] ,searchParams);
// TODO: optimise here
int* neighbors = resultSet.getNeighbors();
......@@ -234,8 +231,8 @@ int Index<T>::radiusSearch(const Matrix<T>& query, Matrix<int>& indices, Matrix<
}
template<typename T>
void Index<T>::save(string filename)
template<typename Distance>
void Index<Distance>::save(string filename)
{
FILE* fout = fopen(filename.c_str(), "wb");
if (fout==NULL) {
......@@ -247,23 +244,24 @@ void Index<T>::save(string filename)
}
template<typename T>
int Index<T>::size() const
template<typename Distance>
int Index<Distance>::size() const
{
return nnIndex->size();
}
template<typename T>
int Index<T>::veclen() const
template<typename Distance>
int Index<Distance>::veclen() const
{
return nnIndex->veclen();
}
template <typename ELEM_TYPE, typename DIST_TYPE>
int hierarchicalClustering(const Matrix<ELEM_TYPE>& features, Matrix<DIST_TYPE>& centers, const KMeansIndexParams& params)
template <typename Distance>
int hierarchicalClustering(const Matrix<typename Distance::ElementType>& features, Matrix<typename Distance::ResultType>& centers,
const KMeansIndexParams& params, Distance d)
{
KMeansIndex<ELEM_TYPE> kmeans(features, params);
KMeansIndex<Distance> kmeans(features, params, d);
kmeans.buildIndex();
int clusterNum = kmeans.getClusterCenters(centers);
......
......@@ -37,8 +37,8 @@
#include "flann/util/saving.h"
#include "flann/nn/ground_truth.h"
// index types
#include "flann/algorithms/kdtree_index2.h"
#include "flann/algorithms/kdtree_index.h"
#include "flann/algorithms/kdtree_simple_index.h"
#include "flann/algorithms/kmeans_index.h"
#include "flann/algorithms/composite_index.h"
#include "flann/algorithms/linear_index.h"
......@@ -62,11 +62,6 @@ void log_verbosity(int level)
}
}
void set_distance_type(flann_distance_t distance_type, int order)
{
flann_distance_type = distance_type;
flann_minkowski_order = order;
}
IndexParams* IndexParams::createFromParameters(const FLANNParameters& p)
{
......@@ -84,7 +79,7 @@ public:
{
ParamsFactory::instance().register_<LinearIndexParams>(LINEAR);
ParamsFactory::instance().register_<KDTreeIndexParams>(KDTREE);
ParamsFactory::instance().register_<KDTreeIndex2Params>(KDTREE2);
ParamsFactory::instance().register_<KDTreeSimpleIndexParams>(KDTREE_SIMPLE);
ParamsFactory::instance().register_<KMeansIndexParams>(KMEANS);
ParamsFactory::instance().register_<CompositeIndexParams>(COMPOSITE);
ParamsFactory::instance().register_<AutotunedIndexParams>(AUTOTUNED);
......
......@@ -216,6 +216,8 @@ void Index<T>::knnSearch(const flann::Matrix<T>& queries, flann::Matrix<int>& in
template<typename T>
int Index<T>::radiusSearch(const flann::Matrix<T>& query, flann::Matrix<int>& indices, flann::Matrix<float>& dists, float radius, const SearchParams& params)
{
// TODO: fix this
// mpi::communicator world;
// flann::Matrix<int> local_indices(new int[indices.rows*indices.cols],indices.rows, indices.cols);
// flann::Matrix<float> local_dists(new float[dists.rows*dists.cols],dists.rows, dists.cols);
......
......@@ -41,7 +41,7 @@ enum flann_algorithm_t {
KDTREE = 1,
KMEANS = 2,
COMPOSITE = 3,
KDTREE2 = 4,
KDTREE_SIMPLE = 4,
SAVED = 254,
AUTOTUNED = 255
};
......@@ -65,7 +65,7 @@ enum flann_distance_t {
MANHATTAN = 2,
MINKOWSKI = 3,
MAX_DIST = 4,
HIK = 5,
HIST_INTERSECT = 5,
HELLINGER = 6,
CS = 7,
CHI_SQUARE = 7,
......@@ -74,16 +74,16 @@ enum flann_distance_t {
};
enum flann_datatype_t {
INT8 = 0,
INT16 = 1,
INT32 = 2,
INT64 = 3,
UINT8 = 4,
UINT16 = 5,
UINT32 = 6,
UINT64 = 7,
FLOAT32 = 8,
FLOAT64 = 9
FLANN_INT8 = 0,
FLANN_INT16 = 1,
FLANN_INT32 = 2,
FLANN_INT64 = 3,
FLANN_UINT8 = 4,
FLANN_UINT16 = 5,
FLANN_UINT32 = 6,
FLANN_UINT64 = 7,
FLANN_FLOAT32 = 8,
FLANN_FLOAT64 = 9
};
......@@ -97,6 +97,7 @@ struct FLANNParameters {
/* kdtree index parameters */
int trees; /* number of randomized trees to use (for kdtree) */
int leaf_max_size;
/* kmeans index parameters */
int branching; /* branching factor (for kmeans tree) */
......@@ -124,26 +125,6 @@ struct FLANNParameters {
namespace flann {
template <typename ELEM_TYPE>
struct DistType
{
typedef ELEM_TYPE type;
};
template <>
struct DistType<unsigned char>
{
typedef float type;
};
template <>
struct DistType<int>
{
typedef float type;
};
class FLANNException : public std::runtime_error {
public:
FLANNException(const char* message) : std::runtime_error(message) { }
......@@ -159,7 +140,7 @@ protected:
public:
static IndexParams* createFromParameters(const FLANNParameters& p);
virtual flann_algorithm_t getIndexType() const = 0;
virtual flann_algorithm_t getIndexType() const { return algorithm; };
virtual void fromParameters(const FLANNParameters& p) = 0;
virtual void toParameters(FLANNParameters& p) const = 0;
......
......@@ -104,11 +104,9 @@ void load_from_file(flann::Matrix<T>& dataset, const std::string& filename, cons
hsize_t dims_out[2];
H5Sget_simple_extent_dims(space_id, dims_out, NULL);
dataset.rows = dims_out[0];
dataset.cols = dims_out[1];
dataset.data = new T[dataset.rows*dataset.cols];
dataset = flann::Matrix<T>(new T[dims_out[0]*dims_out[1]], dims_out[0], dims_out[1]);
status = H5Dread(dataset_id, get_hdf5_type<T>(), H5S_ALL, H5S_ALL, H5P_DEFAULT, dataset.data);
status = H5Dread(dataset_id, get_hdf5_type<T>(), H5S_ALL, H5S_ALL, H5P_DEFAULT, dataset[0]);
CHECK_ERROR(status, "Error reading dataset");
H5Sclose(space_id);
......
......@@ -38,22 +38,22 @@
namespace flann
{
template <typename T>
void find_nearest(const Matrix<T>& dataset, T* query, int* matches, int nn, int skip = 0)
template <typename Distance>
void find_nearest(const Matrix<typename Distance::ElementType>& dataset, typename Distance::ElementType* query, int* matches, int nn,
int skip = 0, Distance distance = Distance())
{
typedef typename Distance::ElementType ElementType;
int n = nn + skip;
T* query_end = query + dataset.cols;
int* match = new int[n];
ElementType* dists = new ElementType[n];
long* match = new long[n];
T* dists = new T[n];
dists[0] = flann_dist(query, query_end, dataset[0]);
dists[0] = distance(dataset[0], query, dataset.cols);
match[0] = 0;
int dcnt = 1;
for (size_t i=1;i<dataset.rows;++i) {
T tmp = flann_dist(query, query_end, dataset[i]);
ElementType tmp = distance(dataset[i], query, dataset.cols);
if (dcnt<n) {
match[dcnt] = i;
......@@ -82,11 +82,12 @@ void find_nearest(const Matrix<T>& dataset, T* query, int* matches, int nn, int
}
template <typename T>
void compute_ground_truth(const Matrix<T>& dataset, const Matrix<T>& testset, Matrix<int>& matches, int skip=0)
template <typename Distance>
void compute_ground_truth(const Matrix<typename Distance::ElementType>& dataset, const Matrix<typename Distance::ElementType>& testset, Matrix<int>& matches,
int skip=0, Distance d = Distance())
{
for (size_t i=0;i<testset.rows;++i) {
find_nearest(dataset, testset[i], matches[i], matches.cols, skip);
find_nearest<Distance>(dataset, testset[i], matches[i], matches.cols, skip, d);
}
}
......
......@@ -33,6 +33,7 @@
#include <cstring>
#include <cassert>
#include <cmath>
#include "flann/util/matrix.h"
#include "flann/algorithms/nn_index.h"
......@@ -49,14 +50,14 @@ namespace flann
int countCorrectMatches(int* neighbors, int* groundTruth, int n);
template <typename ELEM_TYPE>
float computeDistanceRaport(const Matrix<ELEM_TYPE>& inputData, ELEM_TYPE* target, int* neighbors, int* groundTruth, int veclen, int n)
template <typename Distance>
float computeDistanceRaport(const Matrix<typename Distance::ElementType>& inputData, typename Distance::ElementType* target,
int* neighbors, int* groundTruth, int veclen, int n, Distance distance = Distance() )
{
ELEM_TYPE* target_end = target + veclen;
float ret = 0;
for (int i=0;i<n;++i) {
float den = flann_dist(target,target_end, inputData[groundTruth[i]]);
float num = flann_dist(target,target_end, inputData[neighbors[i]]);
float den = distance(inputData[groundTruth[i]], target, veclen);
float num = distance(inputData[neighbors[i]], target, veclen);
if (den==0 && num==0) {
ret += 1;
......@@ -69,8 +70,9 @@ float computeDistanceRaport(const Matrix<ELEM_TYPE>& inputData, ELEM_TYPE* targe
return ret;
}
template <typename ELEM_TYPE>
float search_with_ground_truth(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& inputData, const Matrix<ELEM_TYPE>& testData, const Matrix<int>& matches, int nn, int checks, float& time, float& dist, int skipMatches)
template <typename Distance>
float search_with_ground_truth(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches, int nn, int checks, float& time, float& dist, int skipMatches)
{
if (matches.cols<size_t(nn)) {
logger.info("matches.cols=%d, nn=%d\n",matches.cols,nn);
......@@ -78,7 +80,7 @@ float search_with_ground_truth(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE
throw FLANNException("Ground truth is not computed for as many neighbors as requested");
}
KNNResultSet<ELEM_TYPE> resultSet(nn+skipMatches);
KNNResultSet resultSet(nn+skipMatches);
SearchParams searchParams(checks);
int correct;
......@@ -91,14 +93,14 @@ float search_with_ground_truth(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE
correct = 0;
distR = 0;
for (size_t i = 0; i < testData.rows; i++) {
ELEM_TYPE* target = testData[i];
resultSet.init(target, testData.cols);
typename Distance::ElementType* target = testData[i];
resultSet.init();
index.findNeighbors(resultSet,target, searchParams);
int* neighbors = resultSet.getNeighbors();
neighbors = neighbors+skipMatches;
correct += countCorrectMatches(neighbors,matches[i], nn);
distR += computeDistanceRaport(inputData, target,neighbors,matches[i], testData.cols, nn);
distR += computeDistanceRaport<Distance>(inputData, target, neighbors, matches[i], testData.cols, nn);
}
t.stop();
}
......@@ -116,8 +118,9 @@ float search_with_ground_truth(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE
}
template <typename ELEM_TYPE>
float test_index_checks(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& inputData, const Matrix<ELEM_TYPE>& testData, const Matrix<int>& matches,
template <typename Distance>
float test_index_checks(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
int checks, float& precision, int nn = 1, int skipMatches = 0)
{
logger.info(" Nodes Precision(%) Time(s) Time/vec(ms) Mean dist\n");
......@@ -130,8 +133,9 @@ float test_index_checks(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& inpu
return time;
}
template <typename ELEM_TYPE>
float test_index_precision(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& inputData, const Matrix<ELEM_TYPE>& testData, const Matrix<int>& matches,
template <typename Distance>
float test_index_precision(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
float precision, int& checks, int nn = 1, int skipMatches = 0)
{
const float SEARCH_EPS = 0.001;
......@@ -200,8 +204,9 @@ float test_index_precision(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& i
}
template <typename ELEM_TYPE>
float test_index_precisions(NNIndex<ELEM_TYPE>& index, const Matrix<ELEM_TYPE>& inputData, const Matrix<ELEM_TYPE>& testData, const Matrix<int>& matches,
template <typename Distance>
float test_index_precisions(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
float* precisions, int precisions_length, int nn = 1, int skipMatches = 0, float maxTime = 0)
{
const float SEARCH_EPS = 0.001;
......
此差异已折叠。
......@@ -33,14 +33,14 @@
namespace flann
{
template<> flann_datatype_t get_flann_datatype<char>() { return INT8; }
template<> flann_datatype_t get_flann_datatype<short>() { return INT16; }
template<> flann_datatype_t get_flann_datatype<int>() { return INT32; }
template<> flann_datatype_t get_flann_datatype<unsigned char>() { return UINT8; }
template<> flann_datatype_t get_flann_datatype<unsigned short>() { return UINT16; }
template<> flann_datatype_t get_flann_datatype<unsigned int>() { return UINT32; }
template<> flann_datatype_t get_flann_datatype<float>() { return FLOAT32; }
template<> flann_datatype_t get_flann_datatype<double>() { return FLOAT64; }
template<> flann_datatype_t get_flann_datatype<char>() { return FLANN_INT8; }
template<> flann_datatype_t get_flann_datatype<short>() { return FLANN_INT16; }
template<> flann_datatype_t get_flann_datatype<int>() { return FLANN_INT32; }
template<> flann_datatype_t get_flann_datatype<unsigned char>() { return FLANN_UINT8; }
template<> flann_datatype_t get_flann_datatype<unsigned short>() { return FLANN_UINT16; }
template<> flann_datatype_t get_flann_datatype<unsigned int>() { return FLANN_UINT32; }
template<> flann_datatype_t get_flann_datatype<float>() { return FLANN_FLOAT32; }
template<> flann_datatype_t get_flann_datatype<double>() { return FLANN_FLOAT64; }
const char FLANN_SIGNATURE[] = "FLANN_INDEX";
......
此差异已折叠。
此差异已折叠。
add_custom_target(tests)
add_custom_target(test)
add_dependencies(test tests)
set(EXECUTABLE_OUTPUT_PATH ${TEST_OUTPUT_PATH})
#add_executable(flann_mt_test flann_mt_test.cpp)
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册