提交 77d3fbea 编写于 作者: J joncrall

Edited python files to conform to pep8

Replaced not 'a in i' to 'a not in i'
Replaced 'a == None' and 'a != None to 'a is None' and 'a is not None'
Replaced 'type(x) == type' to 'isinstance(x, type)'
Edited quotes to be consistent.
Replaced 'from x import *' with explicit imports
Used pep8 spacing in function signatures.
Changed classes to inherit from object
Removed unused imports
上级 495f61ff
...@@ -30,5 +30,3 @@ ...@@ -30,5 +30,3 @@
class FLANNException(Exception): class FLANNException(Exception):
def __init__(self, *args): def __init__(self, *args):
Exception.__init__(self, *args) Exception.__init__(self, *args)
...@@ -24,10 +24,14 @@ ...@@ -24,10 +24,14 @@
#(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF #(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
#THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from ctypes import * #from ctypes import *
#from ctypes.util import find_library #from ctypes.util import find_library
from numpy import float32, float64, uint8, int32, matrix, array, empty, reshape, require from numpy import (float32, float64, uint8, int32, require)
from numpy.ctypeslib import load_library, ndpointer #import ctypes
#import numpy as np
from ctypes import (Structure, c_char_p, c_int, c_float, c_uint, c_long,
c_void_p, cdll, POINTER)
from numpy.ctypeslib import ndpointer
import os import os
import sys import sys
...@@ -45,42 +49,43 @@ class CustomStructure(Structure): ...@@ -45,42 +49,43 @@ class CustomStructure(Structure):
def __init__(self): def __init__(self):
Structure.__init__(self) Structure.__init__(self)
self.__field_names = [ f for (f,t) in self._fields_] self.__field_names = [ f for (f, t) in self._fields_]
self.update(self._defaults_) self.update(self._defaults_)
def update(self, dict): def update(self, dict):
for k,v in dict.items(): for k, v in dict.items():
if k in self.__field_names: if k in self.__field_names:
setattr(self,k,self.__translate(k,v)) setattr(self, k, self.__translate(k, v))
else: else:
raise KeyError("No such member: "+k) raise KeyError('No such member: ' + k)
def __getitem__(self, k): def __getitem__(self, k):
if k in self.__field_names: if k in self.__field_names:
return self.__translate_back(k,getattr(self,k)) return self.__translate_back(k, getattr(self, k))
def __setitem__(self, k, v): def __setitem__(self, k, v):
if k in self.__field_names: if k in self.__field_names:
setattr(self,k,self.__translate(k,v)) setattr(self, k, self.__translate(k, v))
else: else:
raise KeyError("No such member: "+k) raise KeyError('No such member: ' + k)
def keys(self): def keys(self):
return self.__field_names return self.__field_names
def __translate(self,k,v): def __translate(self, k, v):
if k in self._translation_: if k in self._translation_:
if v in self._translation_[k]: if v in self._translation_[k]:
return self._translation_[k][v] return self._translation_[k][v]
return v return v
def __translate_back(self,k,v): def __translate_back(self, k, v):
if k in self._translation_: if k in self._translation_:
for tk,tv in self._translation_[k].items(): for tk, tv in self._translation_[k].items():
if tv==v: if tv == v:
return tk return tk
return v return v
class FLANNParameters(CustomStructure): class FLANNParameters(CustomStructure):
_fields_ = [ _fields_ = [
('algorithm', c_int), ('algorithm', c_int),
...@@ -125,13 +130,13 @@ class FLANNParameters(CustomStructure): ...@@ -125,13 +130,13 @@ class FLANNParameters(CustomStructure):
'table_number_': 12, 'table_number_': 12,
'key_size_': 20, 'key_size_': 20,
'multi_probe_level_': 2, 'multi_probe_level_': 2,
'log_level' : "warning", 'log_level' : 'warning',
'random_seed' : -1 'random_seed' : -1
} }
_translation_ = { _translation_ = {
"algorithm" : {"linear" : 0, "kdtree" : 1, "kmeans" : 2, "composite" : 3, "kdtree_single" : 4, "hierarchical": 5, "lsh": 6, "saved": 254, "autotuned" : 255, "default" : 1}, 'algorithm' : {'linear' : 0, 'kdtree' : 1, 'kmeans' : 2, 'composite' : 3, 'kdtree_single' : 4, 'hierarchical': 5, 'lsh': 6, 'saved': 254, 'autotuned' : 255, 'default' : 1},
"centers_init" : {"random" : 0, "gonzales" : 1, "kmeanspp" : 2, "default" : 0}, 'centers_init' : {'random' : 0, 'gonzales' : 1, 'kmeanspp' : 2, 'default' : 0},
"log_level" : {"none" : 0, "fatal" : 1, "error" : 2, "warning" : 3, "info" : 4, "default" : 2} 'log_level' : {'none' : 0, 'fatal' : 1, 'error' : 2, 'warning' : 3, 'info' : 4, 'default' : 2}
} }
...@@ -140,6 +145,7 @@ allowed_types = [ float32, float64, uint8, int32] ...@@ -140,6 +145,7 @@ allowed_types = [ float32, float64, uint8, int32]
FLANN_INDEX = c_void_p FLANN_INDEX = c_void_p
def load_flann_library(): def load_flann_library():
root_dir = os.path.abspath(os.path.dirname(__file__)) root_dir = os.path.abspath(os.path.dirname(__file__))
...@@ -151,16 +157,16 @@ def load_flann_library(): ...@@ -151,16 +157,16 @@ def load_flann_library():
elif sys.platform == 'darwin': elif sys.platform == 'darwin':
libnames = ['libflann.dylib'] libnames = ['libflann.dylib']
while root_dir!=None: while root_dir is not None:
for libname in libnames: for libname in libnames:
try: try:
#print "Trying ",os.path.join(root_dir,'lib',libname) #print 'Trying ',os.path.join(root_dir,'lib',libname)
flannlib = cdll[os.path.join(root_dir,libdir,libname)] flannlib = cdll[os.path.join(root_dir, libdir, libname)]
return flannlib return flannlib
except Exception: except Exception:
pass pass
try: try:
flannlib = cdll[os.path.join(root_dir,"build",libdir,libname)] flannlib = cdll[os.path.join(root_dir, 'build', libdir, libname)]
return flannlib return flannlib
except Exception: except Exception:
pass pass
...@@ -174,8 +180,8 @@ def load_flann_library(): ...@@ -174,8 +180,8 @@ def load_flann_library():
# a full path as a last resort # a full path as a last resort
for libname in libnames: for libname in libnames:
try: try:
#print "Trying",libname #print 'Trying',libname
flannlib=cdll[libname] flannlib = cdll[libname]
return flannlib return flannlib
except: except:
pass pass
...@@ -183,10 +189,13 @@ def load_flann_library(): ...@@ -183,10 +189,13 @@ def load_flann_library():
return None return None
flannlib = load_flann_library() flannlib = load_flann_library()
if flannlib == None: if flannlib is None:
raise ImportError('Cannot load dynamic library. Did you compile FLANN?') raise ImportError('Cannot load dynamic library. Did you compile FLANN?')
class FlannLib: pass
class FlannLib(object):
pass
flann = FlannLib() flann = FlannLib()
...@@ -196,27 +205,27 @@ flannlib.flann_log_verbosity.argtypes = [ ...@@ -196,27 +205,27 @@ flannlib.flann_log_verbosity.argtypes = [
] ]
flannlib.flann_set_distance_type.restype = None flannlib.flann_set_distance_type.restype = None
flannlib.flann_set_distance_type.argtypes = [ flannlib.flann_set_distance_type.argtypes = [
c_int, c_int,
c_int, c_int,
] ]
type_mappings = ( ('float','float32'), type_mappings = ( ('float', 'float32'),
('double','float64'), ('double', 'float64'),
('byte','uint8'), ('byte', 'uint8'),
('int','int32') ) ('int', 'int32') )
def define_functions(str): def define_functions(str):
for type in type_mappings: for type in type_mappings:
eval(compile(str%{'C':type[0],'numpy':type[1]},"<string>","exec")) eval(compile(str % {'C': type[0], 'numpy': type[1]}, '<string>', 'exec'))
flann.build_index = {} flann.build_index = {}
define_functions(r""" define_functions(r"""
flannlib.flann_build_index_%(C)s.restype = FLANN_INDEX flannlib.flann_build_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_build_index_%(C)s.argtypes = [ flannlib.flann_build_index_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
POINTER(c_float), # speedup POINTER(c_float), # speedup
...@@ -240,7 +249,7 @@ define_functions(r""" ...@@ -240,7 +249,7 @@ define_functions(r"""
flannlib.flann_load_index_%(C)s.restype = FLANN_INDEX flannlib.flann_load_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_load_index_%(C)s.argtypes = [ flannlib.flann_load_index_%(C)s.argtypes = [
c_char_p, #filename c_char_p, #filename
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
] ]
...@@ -251,13 +260,13 @@ flann.find_nearest_neighbors = {} ...@@ -251,13 +260,13 @@ flann.find_nearest_neighbors = {}
define_functions(r""" define_functions(r"""
flannlib.flann_find_nearest_neighbors_%(C)s.restype = c_int flannlib.flann_find_nearest_neighbors_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [ flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn c_int, # nn
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
] ]
...@@ -268,13 +277,13 @@ flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_ ...@@ -268,13 +277,13 @@ flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_
flannlib.flann_find_nearest_neighbors_double.restype = c_int flannlib.flann_find_nearest_neighbors_double.restype = c_int
flannlib.flann_find_nearest_neighbors_double.argtypes = [ flannlib.flann_find_nearest_neighbors_double.argtypes = [
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn c_int, # nn
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
] ]
...@@ -286,10 +295,10 @@ define_functions(r""" ...@@ -286,10 +295,10 @@ define_functions(r"""
flannlib.flann_find_nearest_neighbors_index_%(C)s.restype = c_int flannlib.flann_find_nearest_neighbors_index_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [ flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [
FLANN_INDEX, # index_id FLANN_INDEX, # index_id
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn c_int, # nn
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
] ]
...@@ -299,10 +308,10 @@ flann.find_nearest_neighbors_index[%(numpy)s] = flannlib.flann_find_nearest_neig ...@@ -299,10 +308,10 @@ flann.find_nearest_neighbors_index[%(numpy)s] = flannlib.flann_find_nearest_neig
flannlib.flann_find_nearest_neighbors_index_double.restype = c_int flannlib.flann_find_nearest_neighbors_index_double.restype = c_int
flannlib.flann_find_nearest_neighbors_index_double.argtypes = [ flannlib.flann_find_nearest_neighbors_index_double.argtypes = [
FLANN_INDEX, # index_id FLANN_INDEX, # index_id
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn c_int, # nn
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
] ]
...@@ -313,9 +322,9 @@ define_functions(r""" ...@@ -313,9 +322,9 @@ define_functions(r"""
flannlib.flann_radius_search_%(C)s.restype = c_int flannlib.flann_radius_search_%(C)s.restype = c_int
flannlib.flann_radius_search_%(C)s.argtypes = [ flannlib.flann_radius_search_%(C)s.argtypes = [
FLANN_INDEX, # index_id FLANN_INDEX, # index_id
ndpointer(%(numpy)s, ndim = 1, flags='aligned, c_contiguous'), # query ndpointer(%(numpy)s, ndim=1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float32, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float32, ndim=1, flags='aligned, c_contiguous, writeable'), # dists
c_int, # max_nn c_int, # max_nn
c_float, # radius c_float, # radius
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
...@@ -326,9 +335,9 @@ flann.radius_search[%(numpy)s] = flannlib.flann_radius_search_%(C)s ...@@ -326,9 +335,9 @@ flann.radius_search[%(numpy)s] = flannlib.flann_radius_search_%(C)s
flannlib.flann_radius_search_double.restype = c_int flannlib.flann_radius_search_double.restype = c_int
flannlib.flann_radius_search_double.argtypes = [ flannlib.flann_radius_search_double.argtypes = [
FLANN_INDEX, # index_id FLANN_INDEX, # index_id
ndpointer(float64, ndim = 1, flags='aligned, c_contiguous'), # query ndpointer(float64, ndim=1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float64, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists ndpointer(float64, ndim=1, flags='aligned, c_contiguous, writeable'), # dists
c_int, # max_nn c_int, # max_nn
c_float, # radius c_float, # radius
POINTER(FLANNParameters) # flann_params POINTER(FLANNParameters) # flann_params
...@@ -340,7 +349,7 @@ flann.compute_cluster_centers = {} ...@@ -340,7 +349,7 @@ flann.compute_cluster_centers = {}
define_functions(r""" define_functions(r"""
flannlib.flann_compute_cluster_centers_%(C)s.restype = c_int flannlib.flann_compute_cluster_centers_%(C)s.restype = c_int
flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [ flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
c_int, # clusters c_int, # clusters
...@@ -352,7 +361,7 @@ flann.compute_cluster_centers[%(numpy)s] = flannlib.flann_compute_cluster_center ...@@ -352,7 +361,7 @@ flann.compute_cluster_centers[%(numpy)s] = flannlib.flann_compute_cluster_center
# double is an exception # double is an exception
flannlib.flann_compute_cluster_centers_double.restype = c_int flannlib.flann_compute_cluster_centers_double.restype = c_int
flannlib.flann_compute_cluster_centers_double.argtypes = [ flannlib.flann_compute_cluster_centers_double.argtypes = [
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows c_int, # rows
c_int, # cols c_int, # cols
c_int, # clusters c_int, # clusters
...@@ -373,8 +382,8 @@ flann.free_index[%(numpy)s] = flannlib.flann_free_index_%(C)s ...@@ -373,8 +382,8 @@ flann.free_index[%(numpy)s] = flannlib.flann_free_index_%(C)s
""") """)
def ensure_2d_array(array, flags, **kwargs): def ensure_2d_array(arr, flags, **kwargs):
array = require(array, requirements = flags, **kwargs) arr = require(arr, requirements=flags, **kwargs)
if len(array.shape) == 1: if len(arr.shape) == 1:
array = array.reshape(-1,array.size) arr = arr.reshape(-1, arr.size)
return array return arr
...@@ -24,50 +24,59 @@ ...@@ -24,50 +24,59 @@
#(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF #(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
#THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from pyflann.flann_ctypes import * #from pyflann.flann_ctypes import * # NOQA
from pyflann.exceptions import * import sys
from ctypes import pointer, c_float, byref, c_char_p
from pyflann.flann_ctypes import (flannlib, FLANNParameters, allowed_types,
ensure_2d_array, default_flags, flann)
import numpy as np
from pyflann.exceptions import FLANNException
import numpy.random as _rn import numpy.random as _rn
index_type = int32 index_type = np.int32
def set_distance_type(distance_type, order = 0):
def set_distance_type(distance_type, order=0):
""" """
Sets the distance type used. Possible values: euclidean, manhattan, minkowski, max_dist, Sets the distance type used. Possible values: euclidean, manhattan, minkowski, max_dist,
hik, hellinger, cs, kl. hik, hellinger, cs, kl.
""" """
distance_translation = { "euclidean" : 1, distance_translation = {'euclidean': 1,
"manhattan" : 2, 'manhattan': 2,
"minkowski" : 3, 'minkowski': 3,
"max_dist" : 4, 'max_dist': 4,
"hik" : 5, 'hik': 5,
"hellinger" : 6, 'hellinger': 6,
"chi_square" : 7, 'chi_square': 7,
"cs" : 7, 'cs': 7,
"kullback_leibler" : 8, 'kullback_leibler': 8,
"kl" : 8, 'kl': 8,
} }
if type(distance_type)==str: if isinstance(distance_type, str):
distance_type = distance_translation[distance_type] distance_type = distance_translation[distance_type]
flannlib.flann_set_distance_type(distance_type,order) flannlib.flann_set_distance_type(distance_type, order)
def to_bytes(string): def to_bytes(string):
if sys.hexversion > 0x03000000: if sys.hexversion > 0x03000000:
return bytes(string,'utf-8') return bytes(string, 'utf-8')
return string return string
# This class is derived from an initial implementation by Hoyt Koepke (hoytak@cs.ubc.ca) # This class is derived from an initial implementation by Hoyt Koepke
class FLANN: # (hoytak@cs.ubc.ca)
class FLANN(object):
""" """
This class defines a python interface to the FLANN lirary. This class defines a python interface to the FLANN lirary.
""" """
__rn_gen = _rn.RandomState() __rn_gen = _rn.RandomState()
_as_parameter_ = property( lambda self: self.__curindex ) _as_parameter_ = property(lambda self: self.__curindex)
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
...@@ -88,51 +97,50 @@ class FLANN: ...@@ -88,51 +97,50 @@ class FLANN:
def __del__(self): def __del__(self):
self.delete_index() self.delete_index()
##########################################################################
################################################################################
# actual workhorse functions # actual workhorse functions
def nn(self, pts, qpts, num_neighbors = 1, **kwargs): def nn(self, pts, qpts, num_neighbors=1, **kwargs):
""" """
Returns the num_neighbors nearest points in dataset for each point Returns the num_neighbors nearest points in dataset for each point
in testset. in testset.
""" """
if not pts.dtype.type in allowed_types: if pts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype) raise FLANNException('Cannot handle type: %s' % pts.dtype)
if not qpts.dtype.type in allowed_types: if qpts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype) raise FLANNException('Cannot handle type: %s' % pts.dtype)
if pts.dtype != qpts.dtype: if pts.dtype != qpts.dtype:
raise FLANNException("Data and query must have the same type") raise FLANNException('Data and query must have the same type')
pts = ensure_2d_array(pts,default_flags) pts = ensure_2d_array(pts, default_flags)
qpts = ensure_2d_array(qpts,default_flags) qpts = ensure_2d_array(qpts, default_flags)
npts, dim = pts.shape npts, dim = pts.shape
nqpts = qpts.shape[0] nqpts = qpts.shape[0]
assert(qpts.shape[1] == dim) assert qpts.shape[1] == dim, 'data and query must have the same dims'
assert(npts >= num_neighbors) assert npts >= num_neighbors, 'more neighbors than there are points'
result = empty( (nqpts, num_neighbors), dtype=index_type) result = np.empty((nqpts, num_neighbors), dtype=index_type)
if pts.dtype==float64: if pts.dtype == np.float64:
dists = empty( (nqpts, num_neighbors), dtype=float64) dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
else: else:
dists = empty( (nqpts, num_neighbors), dtype=float32) dists = np.empty((nqpts, num_neighbors), dtype=np.float32)
self.__flann_parameters.update(kwargs) self.__flann_parameters.update(kwargs)
flann.find_nearest_neighbors[pts.dtype.type](pts, npts, dim, flann.find_nearest_neighbors[
qpts, nqpts, result, dists, num_neighbors, pts.dtype.type](
pts, npts, dim, qpts, nqpts, result, dists, num_neighbors,
pointer(self.__flann_parameters)) pointer(self.__flann_parameters))
if num_neighbors == 1: if num_neighbors == 1:
return (result.reshape( nqpts ), dists.reshape(nqpts)) return (result.reshape(nqpts), dists.reshape(nqpts))
else: else:
return (result,dists) return (result, dists)
def build_index(self, pts, **kwargs): def build_index(self, pts, **kwargs):
""" """
...@@ -143,80 +151,85 @@ class FLANN: ...@@ -143,80 +151,85 @@ class FLANN:
the nearest neighbors in this index. the nearest neighbors in this index.
pts is a 2d numpy array or matrix. All the computation is done pts is a 2d numpy array or matrix. All the computation is done
in float32 type, but pts may be any type that is convertable in np.float32 type, but pts may be any type that is convertable
to float32. to np.float32.
""" """
if not pts.dtype.type in allowed_types: if pts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype) raise FLANNException('Cannot handle type: %s' % pts.dtype)
pts = ensure_2d_array(pts,default_flags) pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape npts, dim = pts.shape
self.__ensureRandomSeed(kwargs) self.__ensureRandomSeed(kwargs)
self.__flann_parameters.update(kwargs) self.__flann_parameters.update(kwargs)
if self.__curindex != None: if self.__curindex is not None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None self.__curindex = None
speedup = c_float(0) speedup = c_float(0)
self.__curindex = flann.build_index[pts.dtype.type](pts, npts, dim, byref(speedup), pointer(self.__flann_parameters)) self.__curindex = flann.build_index[pts.dtype.type](
pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
self.__curindex_data = pts self.__curindex_data = pts
self.__curindex_type = pts.dtype.type self.__curindex_type = pts.dtype.type
params = dict(self.__flann_parameters) params = dict(self.__flann_parameters)
params["speedup"] = speedup.value params['speedup'] = speedup.value
return params return params
def save_index(self, filename): def save_index(self, filename):
""" """
This saves the index to a disk file. This saves the index to a disk file.
""" """
if self.__curindex != None: if self.__curindex is not None:
flann.save_index[self.__curindex_type](self.__curindex, c_char_p(to_bytes(filename))) flann.save_index[self.__curindex_type](
self.__curindex, c_char_p(to_bytes(filename)))
def load_index(self, filename, pts): def load_index(self, filename, pts):
""" """
Loads an index previously saved to disk. Loads an index previously saved to disk.
""" """
if not pts.dtype.type in allowed_types: if pts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype) raise FLANNException('Cannot handle type: %s' % pts.dtype)
pts = ensure_2d_array(pts,default_flags) pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape npts, dim = pts.shape
if self.__curindex != None: if self.__curindex is not None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None self.__curindex = None
self.__curindex_data = None self.__curindex_data = None
self.__curindex_type = None self.__curindex_type = None
self.__curindex = flann.load_index[pts.dtype.type](c_char_p(to_bytes(filename)), pts, npts, dim) self.__curindex = flann.load_index[pts.dtype.type](
c_char_p(to_bytes(filename)), pts, npts, dim)
self.__curindex_data = pts self.__curindex_data = pts
self.__curindex_type = pts.dtype.type self.__curindex_type = pts.dtype.type
def nn_index(self, qpts, num_neighbors = 1, **kwargs): def nn_index(self, qpts, num_neighbors=1, **kwargs):
""" """
For each point in querypts, (which may be a single point), it For each point in querypts, (which may be a single point), it
returns the num_neighbors nearest points in the index built by returns the num_neighbors nearest points in the index built by
calling build_index. calling build_index.
""" """
if self.__curindex == None: if self.__curindex is None:
raise FLANNException("build_index(...) method not called first or current index deleted.") raise FLANNException(
'build_index(...) method not called first or current index deleted.')
if not qpts.dtype.type in allowed_types: if qpts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%qpts.dtype) raise FLANNException('Cannot handle type: %s' % qpts.dtype)
if self.__curindex_type != qpts.dtype.type: if self.__curindex_type != qpts.dtype.type:
raise FLANNException("Index and query must have the same type") raise FLANNException('Index and query must have the same type')
qpts = ensure_2d_array(qpts,default_flags) qpts = ensure_2d_array(qpts, default_flags)
npts, dim = self.__curindex_data.shape npts, dim = self.__curindex_data.shape
...@@ -225,56 +238,56 @@ class FLANN: ...@@ -225,56 +238,56 @@ class FLANN:
nqpts = qpts.shape[0] nqpts = qpts.shape[0]
assert(qpts.shape[1] == dim) assert qpts.shape[1] == dim, 'data and query must have the same dims'
assert(npts >= num_neighbors) assert npts >= num_neighbors, 'more neighbors than there are points'
result = empty( (nqpts, num_neighbors), dtype=index_type) result = np.empty((nqpts, num_neighbors), dtype=index_type)
if self.__curindex_type==float64: if self.__curindex_type == np.float64:
dists = empty( (nqpts, num_neighbors), dtype=float64) dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
else: else:
dists = empty( (nqpts, num_neighbors), dtype=float32) dists = np.empty((nqpts, num_neighbors), dtype=np.float32)
self.__flann_parameters.update(kwargs) self.__flann_parameters.update(kwargs)
flann.find_nearest_neighbors_index[self.__curindex_type](self.__curindex, flann.find_nearest_neighbors_index[
qpts, nqpts, self.__curindex_type](
result, dists, num_neighbors, self.__curindex, qpts, nqpts, result, dists, num_neighbors,
pointer(self.__flann_parameters)) pointer(self.__flann_parameters))
if num_neighbors == 1: if num_neighbors == 1:
return (result.reshape( nqpts ), dists.reshape( nqpts )) return (result.reshape(nqpts), dists.reshape(nqpts))
else: else:
return (result,dists) return (result, dists)
def nn_radius(self, query, radius, **kwargs): def nn_radius(self, query, radius, **kwargs):
if self.__curindex == None: if self.__curindex is None:
raise FLANNException("build_index(...) method not called first or current index deleted.") raise FLANNException(
'build_index(...) method not called first or current index deleted.')
if not query.dtype.type in allowed_types: if query.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%query.dtype) raise FLANNException('Cannot handle type: %s' % query.dtype)
if self.__curindex_type != query.dtype.type: if self.__curindex_type != query.dtype.type:
raise FLANNException("Index and query must have the same type") raise FLANNException('Index and query must have the same type')
npts, dim = self.__curindex_data.shape npts, dim = self.__curindex_data.shape
assert(query.shape[0]==dim) assert query.shape[0] == dim, 'data and query must have the same dims'
result = empty( npts, dtype=index_type) result = np.empty(npts, dtype=index_type)
if self.__curindex_type==float64: if self.__curindex_type == np.float64:
dists = empty( npts, dtype=float64) dists = np.empty(npts, dtype=np.float64)
else: else:
dists = empty( npts, dtype=float32) dists = np.empty(npts, dtype=np.float32)
self.__flann_parameters.update(kwargs) self.__flann_parameters.update(kwargs)
nn = flann.radius_search[self.__curindex_type](self.__curindex, query, nn = flann.radius_search[
result, dists, npts, self.__curindex_type](
radius, pointer(self.__flann_parameters)) self.__curindex, query, result, dists, npts, radius,
pointer(self.__flann_parameters))
return (result[0:nn],dists[0:nn]) return (result[0:nn], dists[0:nn])
def delete_index(self, **kwargs): def delete_index(self, **kwargs):
""" """
...@@ -284,16 +297,17 @@ class FLANN: ...@@ -284,16 +297,17 @@ class FLANN:
self.__flann_parameters.update(kwargs) self.__flann_parameters.update(kwargs)
if self.__curindex != None: if self.__curindex is not None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters)) flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None self.__curindex = None
self.__curindex_data = None self.__curindex_data = None
########################################################################################## ##########################################################################
# Clustering functions # Clustering functions
def kmeans(self, pts, num_clusters, max_iterations = None, def kmeans(self, pts, num_clusters, max_iterations=None,
dtype = None, **kwargs): dtype=None, **kwargs):
""" """
Runs kmeans on pts with num_clusters centroids. Returns a Runs kmeans on pts with num_clusters centroids. Returns a
numpy array of size num_clusters x dim. numpy array of size num_clusters x dim.
...@@ -311,18 +325,18 @@ class FLANN: ...@@ -311,18 +325,18 @@ class FLANN:
raise FLANNException('num_clusters must be an integer >= 1') raise FLANNException('num_clusters must be an integer >= 1')
if num_clusters == 1: if num_clusters == 1:
if dtype == None or dtype == pts.dtype: if dtype is None or dtype == pts.dtype:
return mean(pts, 0).reshape(1, pts.shape[1]) return np.mean(pts, 0).reshape(1, pts.shape[1])
else: else:
return dtype(mean(pts, 0).reshape(1, pts.shape[1])) return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))
return self.hierarchical_kmeans(pts, int(num_clusters), 1, return self.hierarchical_kmeans(pts, int(num_clusters), 1,
max_iterations, max_iterations,
dtype, **kwargs) dtype, **kwargs)
def hierarchical_kmeans(self, pts, branch_size, num_branches, def hierarchical_kmeans(self, pts, branch_size, num_branches,
max_iterations = None, max_iterations=None,
dtype = None, **kwargs): dtype=None, **kwargs):
""" """
Clusters the data by using multiple runs of kmeans to Clusters the data by using multiple runs of kmeans to
recursively partition the dataset. The number of resulting recursively partition the dataset. The number of resulting
...@@ -339,8 +353,8 @@ class FLANN: ...@@ -339,8 +353,8 @@ class FLANN:
# First verify the paremeters are sensible. # First verify the paremeters are sensible.
if not pts.dtype.type in allowed_types: if pts.dtype.type not in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype) raise FLANNException('Cannot handle type: %s' % pts.dtype)
if int(branch_size) != branch_size or branch_size < 2: if int(branch_size) != branch_size or branch_size < 2:
raise FLANNException('branch_size must be an integer >= 2.') raise FLANNException('branch_size must be an integer >= 2.')
...@@ -352,51 +366,46 @@ class FLANN: ...@@ -352,51 +366,46 @@ class FLANN:
num_branches = int(num_branches) num_branches = int(num_branches)
if max_iterations == None: if max_iterations is None:
max_iterations = -1 max_iterations = -1
else: else:
max_iterations = int(max_iterations) max_iterations = int(max_iterations)
# init the arrays and starting values # init the arrays and starting values
pts = ensure_2d_array(pts,default_flags) pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape npts, dim = pts.shape
num_clusters = (branch_size-1)*num_branches+1; num_clusters = (branch_size - 1) * num_branches + 1
if pts.dtype.type == float64: if pts.dtype.type == np.float64:
result = empty( (num_clusters, dim), dtype=float64) result = np.empty((num_clusters, dim), dtype=np.float64)
else: else:
result = empty( (num_clusters, dim), dtype=float32) result = np.empty((num_clusters, dim), dtype=np.float32)
# set all the parameters appropriately # set all the parameters appropriately
self.__ensureRandomSeed(kwargs) self.__ensureRandomSeed(kwargs)
params = {"iterations" : max_iterations, params = {'iterations': max_iterations,
"algorithm" : 'kmeans', 'algorithm': 'kmeans',
"branching" : branch_size, 'branching': branch_size,
"random_seed" : kwargs['random_seed']} 'random_seed': kwargs['random_seed']}
self.__flann_parameters.update(params) self.__flann_parameters.update(params)
numclusters = flann.compute_cluster_centers[pts.dtype.type](pts, npts, dim, numclusters = flann.compute_cluster_centers[pts.dtype.type](
num_clusters, result, pts, npts, dim, num_clusters, result,
pointer(self.__flann_parameters)) pointer(self.__flann_parameters))
if numclusters <= 0: if numclusters <= 0:
raise FLANNException('Error occured during clustering procedure.') raise FLANNException('Error occured during clustering procedure.')
if dtype == None: if dtype is None:
return result return result
else: else:
return dtype(result) return dtype(result)
########################################################################################## ##########################################################################
# internal bookkeeping functions # internal bookkeeping functions
def __ensureRandomSeed(self, kwargs): def __ensureRandomSeed(self, kwargs):
if not 'random_seed' in kwargs: if 'random_seed' not in kwargs:
kwargs['random_seed'] = self.__rn_gen.randint(2**30) kwargs['random_seed'] = self.__rn_gen.randint(2 ** 30)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册