diff --git a/src/python/pyflann/exceptions.py b/src/python/pyflann/exceptions.py index 8be34c1b1b26335926b37b653d7ab694adcae155..361b5e7b792d592d64a3a0b69fdcbff4d2b0aec2 100644 --- a/src/python/pyflann/exceptions.py +++ b/src/python/pyflann/exceptions.py @@ -30,5 +30,3 @@ class FLANNException(Exception): def __init__(self, *args): Exception.__init__(self, *args) - - diff --git a/src/python/pyflann/flann_ctypes.py b/src/python/pyflann/flann_ctypes.py index b060f5101378320b98c3c47271fa59610b82445d..b851587d4d5e9230e8c4c80996f7022999ca32fb 100644 --- a/src/python/pyflann/flann_ctypes.py +++ b/src/python/pyflann/flann_ctypes.py @@ -24,10 +24,14 @@ #(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF #THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from ctypes import * +#from ctypes import * #from ctypes.util import find_library -from numpy import float32, float64, uint8, int32, matrix, array, empty, reshape, require -from numpy.ctypeslib import load_library, ndpointer +from numpy import (float32, float64, uint8, int32, require) +#import ctypes +#import numpy as np +from ctypes import (Structure, c_char_p, c_int, c_float, c_uint, c_long, + c_void_p, cdll, POINTER) +from numpy.ctypeslib import ndpointer import os import sys @@ -42,44 +46,45 @@ class CustomStructure(Structure): """ _defaults_ = {} _translation_ = {} - + def __init__(self): Structure.__init__(self) - self.__field_names = [ f for (f,t) in self._fields_] - self.update(self._defaults_) - + self.__field_names = [ f for (f, t) in self._fields_] + self.update(self._defaults_) + def update(self, dict): - for k,v in dict.items(): + for k, v in dict.items(): if k in self.__field_names: - setattr(self,k,self.__translate(k,v)) + setattr(self, k, self.__translate(k, v)) else: - raise KeyError("No such member: "+k) - + raise KeyError('No such member: ' + k) + def __getitem__(self, k): if k in self.__field_names: - return self.__translate_back(k,getattr(self,k)) - + return self.__translate_back(k, getattr(self, k)) + def __setitem__(self, k, v): if k in self.__field_names: - setattr(self,k,self.__translate(k,v)) + setattr(self, k, self.__translate(k, v)) else: - raise KeyError("No such member: "+k) - + raise KeyError('No such member: ' + k) + def keys(self): - return self.__field_names + return self.__field_names - def __translate(self,k,v): + def __translate(self, k, v): if k in self._translation_: if v in self._translation_[k]: return self._translation_[k][v] - return v + return v - def __translate_back(self,k,v): + def __translate_back(self, k, v): if k in self._translation_: - for tk,tv in self._translation_[k].items(): - if tv==v: + for tk, tv in self._translation_[k].items(): + if tv == v: return tk - return v + return v + class FLANNParameters(CustomStructure): _fields_ = [ @@ -125,25 +130,26 @@ class FLANNParameters(CustomStructure): 'table_number_': 12, 'key_size_': 20, 'multi_probe_level_': 2, - 'log_level' : "warning", + 'log_level' : 'warning', 'random_seed' : -1 - } + } _translation_ = { - "algorithm" : {"linear" : 0, "kdtree" : 1, "kmeans" : 2, "composite" : 3, "kdtree_single" : 4, "hierarchical": 5, "lsh": 6, "saved": 254, "autotuned" : 255, "default" : 1}, - "centers_init" : {"random" : 0, "gonzales" : 1, "kmeanspp" : 2, "default" : 0}, - "log_level" : {"none" : 0, "fatal" : 1, "error" : 2, "warning" : 3, "info" : 4, "default" : 2} + 'algorithm' : {'linear' : 0, 'kdtree' : 1, 'kmeans' : 2, 'composite' : 3, 'kdtree_single' : 4, 'hierarchical': 5, 'lsh': 6, 'saved': 254, 'autotuned' : 255, 'default' : 1}, + 'centers_init' : {'random' : 0, 'gonzales' : 1, 'kmeanspp' : 2, 'default' : 0}, + 'log_level' : {'none' : 0, 'fatal' : 1, 'error' : 2, 'warning' : 3, 'info' : 4, 'default' : 2} } - - + + default_flags = ['C_CONTIGUOUS', 'ALIGNED'] -allowed_types = [ float32, float64, uint8, int32] +allowed_types = [ float32, float64, uint8, int32] FLANN_INDEX = c_void_p + def load_flann_library(): root_dir = os.path.abspath(os.path.dirname(__file__)) - + libnames = ['libflann.so'] libdir = 'lib' if sys.platform == 'win32': @@ -151,16 +157,16 @@ def load_flann_library(): elif sys.platform == 'darwin': libnames = ['libflann.dylib'] - while root_dir!=None: + while root_dir is not None: for libname in libnames: try: - #print "Trying ",os.path.join(root_dir,'lib',libname) - flannlib = cdll[os.path.join(root_dir,libdir,libname)] + #print 'Trying ',os.path.join(root_dir,'lib',libname) + flannlib = cdll[os.path.join(root_dir, libdir, libname)] return flannlib except Exception: pass try: - flannlib = cdll[os.path.join(root_dir,"build",libdir,libname)] + flannlib = cdll[os.path.join(root_dir, 'build', libdir, libname)] return flannlib except Exception: pass @@ -174,8 +180,8 @@ def load_flann_library(): # a full path as a last resort for libname in libnames: try: - #print "Trying",libname - flannlib=cdll[libname] + #print 'Trying',libname + flannlib = cdll[libname] return flannlib except: pass @@ -183,43 +189,46 @@ def load_flann_library(): return None flannlib = load_flann_library() -if flannlib == None: +if flannlib is None: raise ImportError('Cannot load dynamic library. Did you compile FLANN?') -class FlannLib: pass + +class FlannLib(object): + pass + flann = FlannLib() flannlib.flann_log_verbosity.restype = None -flannlib.flann_log_verbosity.argtypes = [ - c_int # level +flannlib.flann_log_verbosity.argtypes = [ + c_int # level ] - flannlib.flann_set_distance_type.restype = None -flannlib.flann_set_distance_type.argtypes = [ - c_int, - c_int, +flannlib.flann_set_distance_type.argtypes = [ + c_int, + c_int, ] -type_mappings = ( ('float','float32'), - ('double','float64'), - ('byte','uint8'), - ('int','int32') ) +type_mappings = ( ('float', 'float32'), + ('double', 'float64'), + ('byte', 'uint8'), + ('int', 'int32') ) + def define_functions(str): for type in type_mappings: - eval(compile(str%{'C':type[0],'numpy':type[1]},"","exec")) + eval(compile(str % {'C': type[0], 'numpy': type[1]}, '', 'exec')) flann.build_index = {} define_functions(r""" flannlib.flann_build_index_%(C)s.restype = FLANN_INDEX -flannlib.flann_build_index_%(C)s.argtypes = [ - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset - c_int, # rows - c_int, # cols - POINTER(c_float), # speedup +flannlib.flann_build_index_%(C)s.argtypes = [ + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset + c_int, # rows + c_int, # cols + POINTER(c_float), # speedup POINTER(FLANNParameters) # flann_params ] flann.build_index[%(numpy)s] = flannlib.flann_build_index_%(C)s @@ -229,9 +238,9 @@ flann.save_index = {} define_functions(r""" flannlib.flann_save_index_%(C)s.restype = None flannlib.flann_save_index_%(C)s.argtypes = [ - FLANN_INDEX, # index_id - c_char_p #filename -] + FLANN_INDEX, # index_id + c_char_p #filename +] flann.save_index[%(numpy)s] = flannlib.flann_save_index_%(C)s """) @@ -239,26 +248,26 @@ flann.load_index = {} define_functions(r""" flannlib.flann_load_index_%(C)s.restype = FLANN_INDEX flannlib.flann_load_index_%(C)s.argtypes = [ - c_char_p, #filename - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset - c_int, # rows - c_int, # cols + c_char_p, #filename + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset + c_int, # rows + c_int, # cols ] flann.load_index[%(numpy)s] = flannlib.flann_load_index_%(C)s """) -flann.find_nearest_neighbors = {} -define_functions(r""" +flann.find_nearest_neighbors = {} +define_functions(r""" flannlib.flann_find_nearest_neighbors_%(C)s.restype = c_int -flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [ - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset - c_int, # rows - c_int, # cols - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset +flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [ + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset + c_int, # rows + c_int, # cols + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset c_int, # tcount - ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result - ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists - c_int, # nn + ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result + ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists + c_int, # nn POINTER(FLANNParameters) # flann_params ] flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_%(C)s @@ -267,16 +276,16 @@ flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_ # fix definition for the 'double' case flannlib.flann_find_nearest_neighbors_double.restype = c_int -flannlib.flann_find_nearest_neighbors_double.argtypes = [ - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset - c_int, # rows - c_int, # cols - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset - c_int, # tcount - ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists - c_int, # nn - POINTER(FLANNParameters) # flann_params +flannlib.flann_find_nearest_neighbors_double.argtypes = [ + ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset + c_int, # rows + c_int, # cols + ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset + c_int, # tcount + ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result + ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists + c_int, # nn + POINTER(FLANNParameters) # flann_params ] flann.find_nearest_neighbors[float64] = flannlib.flann_find_nearest_neighbors_double @@ -284,54 +293,54 @@ flann.find_nearest_neighbors[float64] = flannlib.flann_find_nearest_neighbors_do flann.find_nearest_neighbors_index = {} define_functions(r""" flannlib.flann_find_nearest_neighbors_index_%(C)s.restype = c_int -flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [ - FLANN_INDEX, # index_id - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset +flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [ + FLANN_INDEX, # index_id + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset c_int, # tcount - ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result - ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists - c_int, # nn + ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result + ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists + c_int, # nn POINTER(FLANNParameters) # flann_params ] flann.find_nearest_neighbors_index[%(numpy)s] = flannlib.flann_find_nearest_neighbors_index_%(C)s """) flannlib.flann_find_nearest_neighbors_index_double.restype = c_int -flannlib.flann_find_nearest_neighbors_index_double.argtypes = [ - FLANN_INDEX, # index_id - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset - c_int, # tcount - ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists - c_int, # nn - POINTER(FLANNParameters) # flann_params +flannlib.flann_find_nearest_neighbors_index_double.argtypes = [ + FLANN_INDEX, # index_id + ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset + c_int, # tcount + ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result + ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists + c_int, # nn + POINTER(FLANNParameters) # flann_params ] flann.find_nearest_neighbors_index[float64] = flannlib.flann_find_nearest_neighbors_index_double flann.radius_search = {} define_functions(r""" flannlib.flann_radius_search_%(C)s.restype = c_int -flannlib.flann_radius_search_%(C)s.argtypes = [ - FLANN_INDEX, # index_id - ndpointer(%(numpy)s, ndim = 1, flags='aligned, c_contiguous'), # query - ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices - ndpointer(float32, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists - c_int, # max_nn - c_float, # radius +flannlib.flann_radius_search_%(C)s.argtypes = [ + FLANN_INDEX, # index_id + ndpointer(%(numpy)s, ndim=1, flags='aligned, c_contiguous'), # query + ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices + ndpointer(float32, ndim=1, flags='aligned, c_contiguous, writeable'), # dists + c_int, # max_nn + c_float, # radius POINTER(FLANNParameters) # flann_params ] flann.radius_search[%(numpy)s] = flannlib.flann_radius_search_%(C)s """) flannlib.flann_radius_search_double.restype = c_int -flannlib.flann_radius_search_double.argtypes = [ - FLANN_INDEX, # index_id - ndpointer(float64, ndim = 1, flags='aligned, c_contiguous'), # query - ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices - ndpointer(float64, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists - c_int, # max_nn - c_float, # radius - POINTER(FLANNParameters) # flann_params +flannlib.flann_radius_search_double.argtypes = [ + FLANN_INDEX, # index_id + ndpointer(float64, ndim=1, flags='aligned, c_contiguous'), # query + ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices + ndpointer(float64, ndim=1, flags='aligned, c_contiguous, writeable'), # dists + c_int, # max_nn + c_float, # radius + POINTER(FLANNParameters) # flann_params ] flann.radius_search[float64] = flannlib.flann_radius_search_double @@ -339,25 +348,25 @@ flann.radius_search[float64] = flannlib.flann_radius_search_double flann.compute_cluster_centers = {} define_functions(r""" flannlib.flann_compute_cluster_centers_%(C)s.restype = c_int -flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [ - ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset +flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [ + ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset c_int, # rows c_int, # cols - c_int, # clusters - ndpointer(float32, flags='aligned, c_contiguous, writeable'), # result + c_int, # clusters + ndpointer(float32, flags='aligned, c_contiguous, writeable'), # result POINTER(FLANNParameters) # flann_params ] flann.compute_cluster_centers[%(numpy)s] = flannlib.flann_compute_cluster_centers_%(C)s """) # double is an exception flannlib.flann_compute_cluster_centers_double.restype = c_int -flannlib.flann_compute_cluster_centers_double.argtypes = [ - ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset - c_int, # rows - c_int, # cols - c_int, # clusters - ndpointer(float64, flags='aligned, c_contiguous, writeable'), # result - POINTER(FLANNParameters) # flann_params +flannlib.flann_compute_cluster_centers_double.argtypes = [ + ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset + c_int, # rows + c_int, # cols + c_int, # clusters + ndpointer(float64, flags='aligned, c_contiguous, writeable'), # result + POINTER(FLANNParameters) # flann_params ] flann.compute_cluster_centers[float64] = flannlib.flann_compute_cluster_centers_double @@ -365,7 +374,7 @@ flann.compute_cluster_centers[float64] = flannlib.flann_compute_cluster_centers_ flann.free_index = {} define_functions(r""" flannlib.flann_free_index_%(C)s.restype = None -flannlib.flann_free_index_%(C)s.argtypes = [ +flannlib.flann_free_index_%(C)s.argtypes = [ FLANN_INDEX, # index_id POINTER(FLANNParameters) # flann_params ] @@ -373,8 +382,8 @@ flann.free_index[%(numpy)s] = flannlib.flann_free_index_%(C)s """) -def ensure_2d_array(array, flags, **kwargs): - array = require(array, requirements = flags, **kwargs) - if len(array.shape) == 1: - array = array.reshape(-1,array.size) - return array +def ensure_2d_array(arr, flags, **kwargs): + arr = require(arr, requirements=flags, **kwargs) + if len(arr.shape) == 1: + arr = arr.reshape(-1, arr.size) + return arr diff --git a/src/python/pyflann/index.py b/src/python/pyflann/index.py index a587f34eb78cfbf7aa7492c70068d05981da28ab..6c984f4c167e56c17b94b47ad9efb21b74746e61 100644 --- a/src/python/pyflann/index.py +++ b/src/python/pyflann/index.py @@ -24,50 +24,59 @@ #(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF #THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from pyflann.flann_ctypes import * -from pyflann.exceptions import * +#from pyflann.flann_ctypes import * # NOQA +import sys +from ctypes import pointer, c_float, byref, c_char_p +from pyflann.flann_ctypes import (flannlib, FLANNParameters, allowed_types, + ensure_2d_array, default_flags, flann) +import numpy as np + +from pyflann.exceptions import FLANNException import numpy.random as _rn -index_type = int32 +index_type = np.int32 -def set_distance_type(distance_type, order = 0): + +def set_distance_type(distance_type, order=0): """ - Sets the distance type used. Possible values: euclidean, manhattan, minkowski, max_dist, + Sets the distance type used. Possible values: euclidean, manhattan, minkowski, max_dist, hik, hellinger, cs, kl. """ - - distance_translation = { "euclidean" : 1, - "manhattan" : 2, - "minkowski" : 3, - "max_dist" : 4, - "hik" : 5, - "hellinger" : 6, - "chi_square" : 7, - "cs" : 7, - "kullback_leibler" : 8, - "kl" : 8, + + distance_translation = {'euclidean': 1, + 'manhattan': 2, + 'minkowski': 3, + 'max_dist': 4, + 'hik': 5, + 'hellinger': 6, + 'chi_square': 7, + 'cs': 7, + 'kullback_leibler': 8, + 'kl': 8, } - if type(distance_type)==str: + if isinstance(distance_type, str): distance_type = distance_translation[distance_type] - flannlib.flann_set_distance_type(distance_type,order) - + flannlib.flann_set_distance_type(distance_type, order) def to_bytes(string): if sys.hexversion > 0x03000000: - return bytes(string,'utf-8') + return bytes(string, 'utf-8') return string -# This class is derived from an initial implementation by Hoyt Koepke (hoytak@cs.ubc.ca) -class FLANN: +# This class is derived from an initial implementation by Hoyt Koepke +# (hoytak@cs.ubc.ca) + + +class FLANN(object): """ This class defines a python interface to the FLANN lirary. """ __rn_gen = _rn.RandomState() - - _as_parameter_ = property( lambda self: self.__curindex ) + + _as_parameter_ = property(lambda self: self.__curindex) def __init__(self, **kwargs): """ @@ -75,64 +84,63 @@ class FLANN: the flann libraries. Any keyword arguments passed to __init__ override the global defaults given. """ - + self.__rn_gen.seed() self.__curindex = None self.__curindex_data = None self.__curindex_type = None - - self.__flann_parameters = FLANNParameters() + + self.__flann_parameters = FLANNParameters() self.__flann_parameters.update(kwargs) def __del__(self): self.delete_index() - - ################################################################################ + ########################################################################## # actual workhorse functions - def nn(self, pts, qpts, num_neighbors = 1, **kwargs): + def nn(self, pts, qpts, num_neighbors=1, **kwargs): """ Returns the num_neighbors nearest points in dataset for each point in testset. """ - - if not pts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%pts.dtype) - if not qpts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%pts.dtype) + if pts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % pts.dtype) + + if qpts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % pts.dtype) if pts.dtype != qpts.dtype: - raise FLANNException("Data and query must have the same type") - - pts = ensure_2d_array(pts,default_flags) - qpts = ensure_2d_array(qpts,default_flags) + raise FLANNException('Data and query must have the same type') + + pts = ensure_2d_array(pts, default_flags) + qpts = ensure_2d_array(qpts, default_flags) npts, dim = pts.shape nqpts = qpts.shape[0] - assert(qpts.shape[1] == dim) - assert(npts >= num_neighbors) + assert qpts.shape[1] == dim, 'data and query must have the same dims' + assert npts >= num_neighbors, 'more neighbors than there are points' - result = empty( (nqpts, num_neighbors), dtype=index_type) - if pts.dtype==float64: - dists = empty( (nqpts, num_neighbors), dtype=float64) + result = np.empty((nqpts, num_neighbors), dtype=index_type) + if pts.dtype == np.float64: + dists = np.empty((nqpts, num_neighbors), dtype=np.float64) else: - dists = empty( (nqpts, num_neighbors), dtype=float32) - + dists = np.empty((nqpts, num_neighbors), dtype=np.float32) + self.__flann_parameters.update(kwargs) - flann.find_nearest_neighbors[pts.dtype.type](pts, npts, dim, - qpts, nqpts, result, dists, num_neighbors, - pointer(self.__flann_parameters)) + flann.find_nearest_neighbors[ + pts.dtype.type]( + pts, npts, dim, qpts, nqpts, result, dists, num_neighbors, + pointer(self.__flann_parameters)) if num_neighbors == 1: - return (result.reshape( nqpts ), dists.reshape(nqpts)) + return (result.reshape(nqpts), dists.reshape(nqpts)) else: - return (result,dists) - + return (result, dists) def build_index(self, pts, **kwargs): """ @@ -143,80 +151,85 @@ class FLANN: the nearest neighbors in this index. pts is a 2d numpy array or matrix. All the computation is done - in float32 type, but pts may be any type that is convertable - to float32. + in np.float32 type, but pts may be any type that is convertable + to np.float32. """ - - if not pts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%pts.dtype) - pts = ensure_2d_array(pts,default_flags) + if pts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % pts.dtype) + + pts = ensure_2d_array(pts, default_flags) npts, dim = pts.shape - + self.__ensureRandomSeed(kwargs) - + self.__flann_parameters.update(kwargs) - if self.__curindex != None: - flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) + if self.__curindex is not None: + flann.free_index[self.__curindex_type]( + self.__curindex, pointer(self.__flann_parameters)) self.__curindex = None - + speedup = c_float(0) - self.__curindex = flann.build_index[pts.dtype.type](pts, npts, dim, byref(speedup), pointer(self.__flann_parameters)) + self.__curindex = flann.build_index[pts.dtype.type]( + pts, npts, dim, byref(speedup), pointer(self.__flann_parameters)) self.__curindex_data = pts self.__curindex_type = pts.dtype.type - + params = dict(self.__flann_parameters) - params["speedup"] = speedup.value - - return params + params['speedup'] = speedup.value + return params def save_index(self, filename): """ This saves the index to a disk file. """ - if self.__curindex != None: - flann.save_index[self.__curindex_type](self.__curindex, c_char_p(to_bytes(filename))) + if self.__curindex is not None: + flann.save_index[self.__curindex_type]( + self.__curindex, c_char_p(to_bytes(filename))) def load_index(self, filename, pts): """ Loads an index previously saved to disk. """ - - if not pts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%pts.dtype) - pts = ensure_2d_array(pts,default_flags) + if pts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % pts.dtype) + + pts = ensure_2d_array(pts, default_flags) npts, dim = pts.shape - if self.__curindex != None: - flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) + if self.__curindex is not None: + flann.free_index[self.__curindex_type]( + self.__curindex, pointer(self.__flann_parameters)) self.__curindex = None self.__curindex_data = None self.__curindex_type = None - - self.__curindex = flann.load_index[pts.dtype.type](c_char_p(to_bytes(filename)), pts, npts, dim) + + self.__curindex = flann.load_index[pts.dtype.type]( + c_char_p(to_bytes(filename)), pts, npts, dim) self.__curindex_data = pts self.__curindex_type = pts.dtype.type - def nn_index(self, qpts, num_neighbors = 1, **kwargs): + def nn_index(self, qpts, num_neighbors=1, **kwargs): """ For each point in querypts, (which may be a single point), it returns the num_neighbors nearest points in the index built by calling build_index. """ - if self.__curindex == None: - raise FLANNException("build_index(...) method not called first or current index deleted.") + if self.__curindex is None: + raise FLANNException( + 'build_index(...) method not called first or current index deleted.') - if not qpts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%qpts.dtype) + if qpts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % qpts.dtype) if self.__curindex_type != qpts.dtype.type: - raise FLANNException("Index and query must have the same type") + raise FLANNException('Index and query must have the same type') - qpts = ensure_2d_array(qpts,default_flags) + qpts = ensure_2d_array(qpts, default_flags) npts, dim = self.__curindex_data.shape @@ -225,122 +238,123 @@ class FLANN: nqpts = qpts.shape[0] - assert(qpts.shape[1] == dim) - assert(npts >= num_neighbors) - - result = empty( (nqpts, num_neighbors), dtype=index_type) - if self.__curindex_type==float64: - dists = empty( (nqpts, num_neighbors), dtype=float64) + assert qpts.shape[1] == dim, 'data and query must have the same dims' + assert npts >= num_neighbors, 'more neighbors than there are points' + + result = np.empty((nqpts, num_neighbors), dtype=index_type) + if self.__curindex_type == np.float64: + dists = np.empty((nqpts, num_neighbors), dtype=np.float64) else: - dists = empty( (nqpts, num_neighbors), dtype=float32) + dists = np.empty((nqpts, num_neighbors), dtype=np.float32) self.__flann_parameters.update(kwargs) - flann.find_nearest_neighbors_index[self.__curindex_type](self.__curindex, - qpts, nqpts, - result, dists, num_neighbors, - pointer(self.__flann_parameters)) + flann.find_nearest_neighbors_index[ + self.__curindex_type]( + self.__curindex, qpts, nqpts, result, dists, num_neighbors, + pointer(self.__flann_parameters)) if num_neighbors == 1: - return (result.reshape( nqpts ), dists.reshape( nqpts )) + return (result.reshape(nqpts), dists.reshape(nqpts)) else: - return (result,dists) - - + return (result, dists) + def nn_radius(self, query, radius, **kwargs): - - if self.__curindex == None: - raise FLANNException("build_index(...) method not called first or current index deleted.") - if not query.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%query.dtype) + if self.__curindex is None: + raise FLANNException( + 'build_index(...) method not called first or current index deleted.') + + if query.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % query.dtype) if self.__curindex_type != query.dtype.type: - raise FLANNException("Index and query must have the same type") - - npts, dim = self.__curindex_data.shape - assert(query.shape[0]==dim) - - result = empty( npts, dtype=index_type) - if self.__curindex_type==float64: - dists = empty( npts, dtype=float64) + raise FLANNException('Index and query must have the same type') + + npts, dim = self.__curindex_data.shape + assert query.shape[0] == dim, 'data and query must have the same dims' + + result = np.empty(npts, dtype=index_type) + if self.__curindex_type == np.float64: + dists = np.empty(npts, dtype=np.float64) else: - dists = empty( npts, dtype=float32) - + dists = np.empty(npts, dtype=np.float32) + self.__flann_parameters.update(kwargs) - nn = flann.radius_search[self.__curindex_type](self.__curindex, query, - result, dists, npts, - radius, pointer(self.__flann_parameters)) - - - return (result[0:nn],dists[0:nn]) + nn = flann.radius_search[ + self.__curindex_type]( + self.__curindex, query, result, dists, npts, radius, + pointer(self.__flann_parameters)) + + return (result[0:nn], dists[0:nn]) def delete_index(self, **kwargs): """ - Deletes the current index freeing all the momory it uses. + Deletes the current index freeing all the momory it uses. The memory used by the dataset that was indexed is not freed. """ self.__flann_parameters.update(kwargs) - - if self.__curindex != None: - flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) + + if self.__curindex is not None: + flann.free_index[self.__curindex_type]( + self.__curindex, pointer(self.__flann_parameters)) self.__curindex = None self.__curindex_data = None - ########################################################################################## + ########################################################################## # Clustering functions - def kmeans(self, pts, num_clusters, max_iterations = None, - dtype = None, **kwargs): + def kmeans(self, pts, num_clusters, max_iterations=None, + dtype=None, **kwargs): """ Runs kmeans on pts with num_clusters centroids. Returns a - numpy array of size num_clusters x dim. + numpy array of size num_clusters x dim. If max_iterations is not None, the algorithm terminates after the given number of iterations regardless of convergence. The default is to run until convergence. If dtype is None (the default), the array returned is the same - type as pts. Otherwise, the returned array is of type dtype. + type as pts. Otherwise, the returned array is of type dtype. """ - + if int(num_clusters) != num_clusters or num_clusters < 1: raise FLANNException('num_clusters must be an integer >= 1') - + if num_clusters == 1: - if dtype == None or dtype == pts.dtype: - return mean(pts, 0).reshape(1, pts.shape[1]) + if dtype is None or dtype == pts.dtype: + return np.mean(pts, 0).reshape(1, pts.shape[1]) else: - return dtype(mean(pts, 0).reshape(1, pts.shape[1])) + return dtype(np.mean(pts, 0).reshape(1, pts.shape[1])) - return self.hierarchical_kmeans(pts, int(num_clusters), 1, - max_iterations, + return self.hierarchical_kmeans(pts, int(num_clusters), 1, + max_iterations, dtype, **kwargs) - + def hierarchical_kmeans(self, pts, branch_size, num_branches, - max_iterations = None, - dtype = None, **kwargs): + max_iterations=None, + dtype=None, **kwargs): """ Clusters the data by using multiple runs of kmeans to recursively partition the dataset. The number of resulting clusters is given by (branch_size-1)*num_branches+1. - + This method can be significantly faster when the number of desired clusters is quite large (e.g. a hundred or more). Higher branch sizes are slower but may give better results. If dtype is None (the default), the array returned is the same - type as pts. Otherwise, the returned array is of type dtype. - + type as pts. Otherwise, the returned array is of type dtype. + """ - + # First verify the paremeters are sensible. - if not pts.dtype.type in allowed_types: - raise FLANNException("Cannot handle type: %s"%pts.dtype) + if pts.dtype.type not in allowed_types: + raise FLANNException('Cannot handle type: %s' % pts.dtype) if int(branch_size) != branch_size or branch_size < 2: raise FLANNException('branch_size must be an integer >= 2.') @@ -352,51 +366,46 @@ class FLANN: num_branches = int(num_branches) - if max_iterations == None: + if max_iterations is None: max_iterations = -1 else: max_iterations = int(max_iterations) - # init the arrays and starting values - pts = ensure_2d_array(pts,default_flags) + pts = ensure_2d_array(pts, default_flags) npts, dim = pts.shape - num_clusters = (branch_size-1)*num_branches+1; - - if pts.dtype.type == float64: - result = empty( (num_clusters, dim), dtype=float64) + num_clusters = (branch_size - 1) * num_branches + 1 + + if pts.dtype.type == np.float64: + result = np.empty((num_clusters, dim), dtype=np.float64) else: - result = empty( (num_clusters, dim), dtype=float32) + result = np.empty((num_clusters, dim), dtype=np.float32) # set all the parameters appropriately - + self.__ensureRandomSeed(kwargs) - - params = {"iterations" : max_iterations, - "algorithm" : 'kmeans', - "branching" : branch_size, - "random_seed" : kwargs['random_seed']} - + + params = {'iterations': max_iterations, + 'algorithm': 'kmeans', + 'branching': branch_size, + 'random_seed': kwargs['random_seed']} + self.__flann_parameters.update(params) - - numclusters = flann.compute_cluster_centers[pts.dtype.type](pts, npts, dim, - num_clusters, result, - pointer(self.__flann_parameters)) + + numclusters = flann.compute_cluster_centers[pts.dtype.type]( + pts, npts, dim, num_clusters, result, + pointer(self.__flann_parameters)) if numclusters <= 0: raise FLANNException('Error occured during clustering procedure.') - if dtype == None: + if dtype is None: return result else: return dtype(result) - - ########################################################################################## + + ########################################################################## # internal bookkeeping functions - def __ensureRandomSeed(self, kwargs): - if not 'random_seed' in kwargs: - kwargs['random_seed'] = self.__rn_gen.randint(2**30) - - - + if 'random_seed' not in kwargs: + kwargs['random_seed'] = self.__rn_gen.randint(2 ** 30)