提交 77d3fbea 编写于 作者: J joncrall

Edited python files to conform to pep8

Replaced not 'a in i' to 'a not in i'
Replaced 'a == None' and 'a != None to 'a is None' and 'a is not None'
Replaced 'type(x) == type' to 'isinstance(x, type)'
Edited quotes to be consistent.
Replaced 'from x import *' with explicit imports
Used pep8 spacing in function signatures.
Changed classes to inherit from object
Removed unused imports
上级 495f61ff
......@@ -30,5 +30,3 @@
class FLANNException(Exception):
def __init__(self, *args):
Exception.__init__(self, *args)
......@@ -24,10 +24,14 @@
#(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
#THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from ctypes import *
#from ctypes import *
#from ctypes.util import find_library
from numpy import float32, float64, uint8, int32, matrix, array, empty, reshape, require
from numpy.ctypeslib import load_library, ndpointer
from numpy import (float32, float64, uint8, int32, require)
#import ctypes
#import numpy as np
from ctypes import (Structure, c_char_p, c_int, c_float, c_uint, c_long,
c_void_p, cdll, POINTER)
from numpy.ctypeslib import ndpointer
import os
import sys
......@@ -45,42 +49,43 @@ class CustomStructure(Structure):
def __init__(self):
Structure.__init__(self)
self.__field_names = [ f for (f,t) in self._fields_]
self.__field_names = [ f for (f, t) in self._fields_]
self.update(self._defaults_)
def update(self, dict):
for k,v in dict.items():
for k, v in dict.items():
if k in self.__field_names:
setattr(self,k,self.__translate(k,v))
setattr(self, k, self.__translate(k, v))
else:
raise KeyError("No such member: "+k)
raise KeyError('No such member: ' + k)
def __getitem__(self, k):
if k in self.__field_names:
return self.__translate_back(k,getattr(self,k))
return self.__translate_back(k, getattr(self, k))
def __setitem__(self, k, v):
if k in self.__field_names:
setattr(self,k,self.__translate(k,v))
setattr(self, k, self.__translate(k, v))
else:
raise KeyError("No such member: "+k)
raise KeyError('No such member: ' + k)
def keys(self):
return self.__field_names
def __translate(self,k,v):
def __translate(self, k, v):
if k in self._translation_:
if v in self._translation_[k]:
return self._translation_[k][v]
return v
def __translate_back(self,k,v):
def __translate_back(self, k, v):
if k in self._translation_:
for tk,tv in self._translation_[k].items():
if tv==v:
for tk, tv in self._translation_[k].items():
if tv == v:
return tk
return v
class FLANNParameters(CustomStructure):
_fields_ = [
('algorithm', c_int),
......@@ -125,13 +130,13 @@ class FLANNParameters(CustomStructure):
'table_number_': 12,
'key_size_': 20,
'multi_probe_level_': 2,
'log_level' : "warning",
'log_level' : 'warning',
'random_seed' : -1
}
_translation_ = {
"algorithm" : {"linear" : 0, "kdtree" : 1, "kmeans" : 2, "composite" : 3, "kdtree_single" : 4, "hierarchical": 5, "lsh": 6, "saved": 254, "autotuned" : 255, "default" : 1},
"centers_init" : {"random" : 0, "gonzales" : 1, "kmeanspp" : 2, "default" : 0},
"log_level" : {"none" : 0, "fatal" : 1, "error" : 2, "warning" : 3, "info" : 4, "default" : 2}
'algorithm' : {'linear' : 0, 'kdtree' : 1, 'kmeans' : 2, 'composite' : 3, 'kdtree_single' : 4, 'hierarchical': 5, 'lsh': 6, 'saved': 254, 'autotuned' : 255, 'default' : 1},
'centers_init' : {'random' : 0, 'gonzales' : 1, 'kmeanspp' : 2, 'default' : 0},
'log_level' : {'none' : 0, 'fatal' : 1, 'error' : 2, 'warning' : 3, 'info' : 4, 'default' : 2}
}
......@@ -140,6 +145,7 @@ allowed_types = [ float32, float64, uint8, int32]
FLANN_INDEX = c_void_p
def load_flann_library():
root_dir = os.path.abspath(os.path.dirname(__file__))
......@@ -151,16 +157,16 @@ def load_flann_library():
elif sys.platform == 'darwin':
libnames = ['libflann.dylib']
while root_dir!=None:
while root_dir is not None:
for libname in libnames:
try:
#print "Trying ",os.path.join(root_dir,'lib',libname)
flannlib = cdll[os.path.join(root_dir,libdir,libname)]
#print 'Trying ',os.path.join(root_dir,'lib',libname)
flannlib = cdll[os.path.join(root_dir, libdir, libname)]
return flannlib
except Exception:
pass
try:
flannlib = cdll[os.path.join(root_dir,"build",libdir,libname)]
flannlib = cdll[os.path.join(root_dir, 'build', libdir, libname)]
return flannlib
except Exception:
pass
......@@ -174,8 +180,8 @@ def load_flann_library():
# a full path as a last resort
for libname in libnames:
try:
#print "Trying",libname
flannlib=cdll[libname]
#print 'Trying',libname
flannlib = cdll[libname]
return flannlib
except:
pass
......@@ -183,10 +189,13 @@ def load_flann_library():
return None
flannlib = load_flann_library()
if flannlib == None:
if flannlib is None:
raise ImportError('Cannot load dynamic library. Did you compile FLANN?')
class FlannLib: pass
class FlannLib(object):
pass
flann = FlannLib()
......@@ -196,27 +205,27 @@ flannlib.flann_log_verbosity.argtypes = [
]
flannlib.flann_set_distance_type.restype = None
flannlib.flann_set_distance_type.argtypes = [
c_int,
c_int,
]
type_mappings = ( ('float','float32'),
('double','float64'),
('byte','uint8'),
('int','int32') )
type_mappings = ( ('float', 'float32'),
('double', 'float64'),
('byte', 'uint8'),
('int', 'int32') )
def define_functions(str):
for type in type_mappings:
eval(compile(str%{'C':type[0],'numpy':type[1]},"<string>","exec"))
eval(compile(str % {'C': type[0], 'numpy': type[1]}, '<string>', 'exec'))
flann.build_index = {}
define_functions(r"""
flannlib.flann_build_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_build_index_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
POINTER(c_float), # speedup
......@@ -240,7 +249,7 @@ define_functions(r"""
flannlib.flann_load_index_%(C)s.restype = FLANN_INDEX
flannlib.flann_load_index_%(C)s.argtypes = [
c_char_p, #filename
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
]
......@@ -251,13 +260,13 @@ flann.find_nearest_neighbors = {}
define_functions(r"""
flannlib.flann_find_nearest_neighbors_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn
POINTER(FLANNParameters) # flann_params
]
......@@ -268,13 +277,13 @@ flann.find_nearest_neighbors[%(numpy)s] = flannlib.flann_find_nearest_neighbors_
flannlib.flann_find_nearest_neighbors_double.restype = c_int
flannlib.flann_find_nearest_neighbors_double.argtypes = [
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset
ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn
POINTER(FLANNParameters) # flann_params
]
......@@ -286,10 +295,10 @@ define_functions(r"""
flannlib.flann_find_nearest_neighbors_index_%(C)s.restype = c_int
flannlib.flann_find_nearest_neighbors_index_%(C)s.argtypes = [
FLANN_INDEX, # index_id
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # testset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float32, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn
POINTER(FLANNParameters) # flann_params
]
......@@ -299,10 +308,10 @@ flann.find_nearest_neighbors_index[%(numpy)s] = flannlib.flann_find_nearest_neig
flannlib.flann_find_nearest_neighbors_index_double.restype = c_int
flannlib.flann_find_nearest_neighbors_index_double.argtypes = [
FLANN_INDEX, # index_id
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # testset
ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # testset
c_int, # tcount
ndpointer(int32, ndim = 2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(int32, ndim=2, flags='aligned, c_contiguous, writeable'), # result
ndpointer(float64, ndim=2, flags='aligned, c_contiguous, writeable'), # dists
c_int, # nn
POINTER(FLANNParameters) # flann_params
]
......@@ -313,9 +322,9 @@ define_functions(r"""
flannlib.flann_radius_search_%(C)s.restype = c_int
flannlib.flann_radius_search_%(C)s.argtypes = [
FLANN_INDEX, # index_id
ndpointer(%(numpy)s, ndim = 1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float32, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(%(numpy)s, ndim=1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float32, ndim=1, flags='aligned, c_contiguous, writeable'), # dists
c_int, # max_nn
c_float, # radius
POINTER(FLANNParameters) # flann_params
......@@ -326,9 +335,9 @@ flann.radius_search[%(numpy)s] = flannlib.flann_radius_search_%(C)s
flannlib.flann_radius_search_double.restype = c_int
flannlib.flann_radius_search_double.argtypes = [
FLANN_INDEX, # index_id
ndpointer(float64, ndim = 1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim = 1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float64, ndim = 1, flags='aligned, c_contiguous, writeable'), # dists
ndpointer(float64, ndim=1, flags='aligned, c_contiguous'), # query
ndpointer(int32, ndim=1, flags='aligned, c_contiguous, writeable'), # indices
ndpointer(float64, ndim=1, flags='aligned, c_contiguous, writeable'), # dists
c_int, # max_nn
c_float, # radius
POINTER(FLANNParameters) # flann_params
......@@ -340,7 +349,7 @@ flann.compute_cluster_centers = {}
define_functions(r"""
flannlib.flann_compute_cluster_centers_%(C)s.restype = c_int
flannlib.flann_compute_cluster_centers_%(C)s.argtypes = [
ndpointer(%(numpy)s, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(%(numpy)s, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
c_int, # clusters
......@@ -352,7 +361,7 @@ flann.compute_cluster_centers[%(numpy)s] = flannlib.flann_compute_cluster_center
# double is an exception
flannlib.flann_compute_cluster_centers_double.restype = c_int
flannlib.flann_compute_cluster_centers_double.argtypes = [
ndpointer(float64, ndim = 2, flags='aligned, c_contiguous'), # dataset
ndpointer(float64, ndim=2, flags='aligned, c_contiguous'), # dataset
c_int, # rows
c_int, # cols
c_int, # clusters
......@@ -373,8 +382,8 @@ flann.free_index[%(numpy)s] = flannlib.flann_free_index_%(C)s
""")
def ensure_2d_array(array, flags, **kwargs):
array = require(array, requirements = flags, **kwargs)
if len(array.shape) == 1:
array = array.reshape(-1,array.size)
return array
def ensure_2d_array(arr, flags, **kwargs):
arr = require(arr, requirements=flags, **kwargs)
if len(arr.shape) == 1:
arr = arr.reshape(-1, arr.size)
return arr
......@@ -24,50 +24,59 @@
#(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
#THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from pyflann.flann_ctypes import *
from pyflann.exceptions import *
#from pyflann.flann_ctypes import * # NOQA
import sys
from ctypes import pointer, c_float, byref, c_char_p
from pyflann.flann_ctypes import (flannlib, FLANNParameters, allowed_types,
ensure_2d_array, default_flags, flann)
import numpy as np
from pyflann.exceptions import FLANNException
import numpy.random as _rn
index_type = int32
index_type = np.int32
def set_distance_type(distance_type, order = 0):
def set_distance_type(distance_type, order=0):
"""
Sets the distance type used. Possible values: euclidean, manhattan, minkowski, max_dist,
hik, hellinger, cs, kl.
"""
distance_translation = { "euclidean" : 1,
"manhattan" : 2,
"minkowski" : 3,
"max_dist" : 4,
"hik" : 5,
"hellinger" : 6,
"chi_square" : 7,
"cs" : 7,
"kullback_leibler" : 8,
"kl" : 8,
distance_translation = {'euclidean': 1,
'manhattan': 2,
'minkowski': 3,
'max_dist': 4,
'hik': 5,
'hellinger': 6,
'chi_square': 7,
'cs': 7,
'kullback_leibler': 8,
'kl': 8,
}
if type(distance_type)==str:
if isinstance(distance_type, str):
distance_type = distance_translation[distance_type]
flannlib.flann_set_distance_type(distance_type,order)
flannlib.flann_set_distance_type(distance_type, order)
def to_bytes(string):
if sys.hexversion > 0x03000000:
return bytes(string,'utf-8')
return bytes(string, 'utf-8')
return string
# This class is derived from an initial implementation by Hoyt Koepke (hoytak@cs.ubc.ca)
class FLANN:
# This class is derived from an initial implementation by Hoyt Koepke
# (hoytak@cs.ubc.ca)
class FLANN(object):
"""
This class defines a python interface to the FLANN lirary.
"""
__rn_gen = _rn.RandomState()
_as_parameter_ = property( lambda self: self.__curindex )
_as_parameter_ = property(lambda self: self.__curindex)
def __init__(self, **kwargs):
"""
......@@ -88,51 +97,50 @@ class FLANN:
def __del__(self):
self.delete_index()
################################################################################
##########################################################################
# actual workhorse functions
def nn(self, pts, qpts, num_neighbors = 1, **kwargs):
def nn(self, pts, qpts, num_neighbors=1, **kwargs):
"""
Returns the num_neighbors nearest points in dataset for each point
in testset.
"""
if not pts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype)
if pts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % pts.dtype)
if not qpts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype)
if qpts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % pts.dtype)
if pts.dtype != qpts.dtype:
raise FLANNException("Data and query must have the same type")
raise FLANNException('Data and query must have the same type')
pts = ensure_2d_array(pts,default_flags)
qpts = ensure_2d_array(qpts,default_flags)
pts = ensure_2d_array(pts, default_flags)
qpts = ensure_2d_array(qpts, default_flags)
npts, dim = pts.shape
nqpts = qpts.shape[0]
assert(qpts.shape[1] == dim)
assert(npts >= num_neighbors)
assert qpts.shape[1] == dim, 'data and query must have the same dims'
assert npts >= num_neighbors, 'more neighbors than there are points'
result = empty( (nqpts, num_neighbors), dtype=index_type)
if pts.dtype==float64:
dists = empty( (nqpts, num_neighbors), dtype=float64)
result = np.empty((nqpts, num_neighbors), dtype=index_type)
if pts.dtype == np.float64:
dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
else:
dists = empty( (nqpts, num_neighbors), dtype=float32)
dists = np.empty((nqpts, num_neighbors), dtype=np.float32)
self.__flann_parameters.update(kwargs)
flann.find_nearest_neighbors[pts.dtype.type](pts, npts, dim,
qpts, nqpts, result, dists, num_neighbors,
flann.find_nearest_neighbors[
pts.dtype.type](
pts, npts, dim, qpts, nqpts, result, dists, num_neighbors,
pointer(self.__flann_parameters))
if num_neighbors == 1:
return (result.reshape( nqpts ), dists.reshape(nqpts))
return (result.reshape(nqpts), dists.reshape(nqpts))
else:
return (result,dists)
return (result, dists)
def build_index(self, pts, **kwargs):
"""
......@@ -143,80 +151,85 @@ class FLANN:
the nearest neighbors in this index.
pts is a 2d numpy array or matrix. All the computation is done
in float32 type, but pts may be any type that is convertable
to float32.
in np.float32 type, but pts may be any type that is convertable
to np.float32.
"""
if not pts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype)
if pts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % pts.dtype)
pts = ensure_2d_array(pts,default_flags)
pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape
self.__ensureRandomSeed(kwargs)
self.__flann_parameters.update(kwargs)
if self.__curindex != None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
if self.__curindex is not None:
flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None
speedup = c_float(0)
self.__curindex = flann.build_index[pts.dtype.type](pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
self.__curindex = flann.build_index[pts.dtype.type](
pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
self.__curindex_data = pts
self.__curindex_type = pts.dtype.type
params = dict(self.__flann_parameters)
params["speedup"] = speedup.value
params['speedup'] = speedup.value
return params
def save_index(self, filename):
"""
This saves the index to a disk file.
"""
if self.__curindex != None:
flann.save_index[self.__curindex_type](self.__curindex, c_char_p(to_bytes(filename)))
if self.__curindex is not None:
flann.save_index[self.__curindex_type](
self.__curindex, c_char_p(to_bytes(filename)))
def load_index(self, filename, pts):
"""
Loads an index previously saved to disk.
"""
if not pts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype)
if pts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % pts.dtype)
pts = ensure_2d_array(pts,default_flags)
pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape
if self.__curindex != None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
if self.__curindex is not None:
flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None
self.__curindex_data = None
self.__curindex_type = None
self.__curindex = flann.load_index[pts.dtype.type](c_char_p(to_bytes(filename)), pts, npts, dim)
self.__curindex = flann.load_index[pts.dtype.type](
c_char_p(to_bytes(filename)), pts, npts, dim)
self.__curindex_data = pts
self.__curindex_type = pts.dtype.type
def nn_index(self, qpts, num_neighbors = 1, **kwargs):
def nn_index(self, qpts, num_neighbors=1, **kwargs):
"""
For each point in querypts, (which may be a single point), it
returns the num_neighbors nearest points in the index built by
calling build_index.
"""
if self.__curindex == None:
raise FLANNException("build_index(...) method not called first or current index deleted.")
if self.__curindex is None:
raise FLANNException(
'build_index(...) method not called first or current index deleted.')
if not qpts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%qpts.dtype)
if qpts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % qpts.dtype)
if self.__curindex_type != qpts.dtype.type:
raise FLANNException("Index and query must have the same type")
raise FLANNException('Index and query must have the same type')
qpts = ensure_2d_array(qpts,default_flags)
qpts = ensure_2d_array(qpts, default_flags)
npts, dim = self.__curindex_data.shape
......@@ -225,56 +238,56 @@ class FLANN:
nqpts = qpts.shape[0]
assert(qpts.shape[1] == dim)
assert(npts >= num_neighbors)
assert qpts.shape[1] == dim, 'data and query must have the same dims'
assert npts >= num_neighbors, 'more neighbors than there are points'
result = empty( (nqpts, num_neighbors), dtype=index_type)
if self.__curindex_type==float64:
dists = empty( (nqpts, num_neighbors), dtype=float64)
result = np.empty((nqpts, num_neighbors), dtype=index_type)
if self.__curindex_type == np.float64:
dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
else:
dists = empty( (nqpts, num_neighbors), dtype=float32)
dists = np.empty((nqpts, num_neighbors), dtype=np.float32)
self.__flann_parameters.update(kwargs)
flann.find_nearest_neighbors_index[self.__curindex_type](self.__curindex,
qpts, nqpts,
result, dists, num_neighbors,
flann.find_nearest_neighbors_index[
self.__curindex_type](
self.__curindex, qpts, nqpts, result, dists, num_neighbors,
pointer(self.__flann_parameters))
if num_neighbors == 1:
return (result.reshape( nqpts ), dists.reshape( nqpts ))
return (result.reshape(nqpts), dists.reshape(nqpts))
else:
return (result,dists)
return (result, dists)
def nn_radius(self, query, radius, **kwargs):
if self.__curindex == None:
raise FLANNException("build_index(...) method not called first or current index deleted.")
if self.__curindex is None:
raise FLANNException(
'build_index(...) method not called first or current index deleted.')
if not query.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%query.dtype)
if query.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % query.dtype)
if self.__curindex_type != query.dtype.type:
raise FLANNException("Index and query must have the same type")
raise FLANNException('Index and query must have the same type')
npts, dim = self.__curindex_data.shape
assert(query.shape[0]==dim)
assert query.shape[0] == dim, 'data and query must have the same dims'
result = empty( npts, dtype=index_type)
if self.__curindex_type==float64:
dists = empty( npts, dtype=float64)
result = np.empty(npts, dtype=index_type)
if self.__curindex_type == np.float64:
dists = np.empty(npts, dtype=np.float64)
else:
dists = empty( npts, dtype=float32)
dists = np.empty(npts, dtype=np.float32)
self.__flann_parameters.update(kwargs)
nn = flann.radius_search[self.__curindex_type](self.__curindex, query,
result, dists, npts,
radius, pointer(self.__flann_parameters))
nn = flann.radius_search[
self.__curindex_type](
self.__curindex, query, result, dists, npts, radius,
pointer(self.__flann_parameters))
return (result[0:nn],dists[0:nn])
return (result[0:nn], dists[0:nn])
def delete_index(self, **kwargs):
"""
......@@ -284,16 +297,17 @@ class FLANN:
self.__flann_parameters.update(kwargs)
if self.__curindex != None:
flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
if self.__curindex is not None:
flann.free_index[self.__curindex_type](
self.__curindex, pointer(self.__flann_parameters))
self.__curindex = None
self.__curindex_data = None
##########################################################################################
##########################################################################
# Clustering functions
def kmeans(self, pts, num_clusters, max_iterations = None,
dtype = None, **kwargs):
def kmeans(self, pts, num_clusters, max_iterations=None,
dtype=None, **kwargs):
"""
Runs kmeans on pts with num_clusters centroids. Returns a
numpy array of size num_clusters x dim.
......@@ -311,18 +325,18 @@ class FLANN:
raise FLANNException('num_clusters must be an integer >= 1')
if num_clusters == 1:
if dtype == None or dtype == pts.dtype:
return mean(pts, 0).reshape(1, pts.shape[1])
if dtype is None or dtype == pts.dtype:
return np.mean(pts, 0).reshape(1, pts.shape[1])
else:
return dtype(mean(pts, 0).reshape(1, pts.shape[1]))
return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))
return self.hierarchical_kmeans(pts, int(num_clusters), 1,
max_iterations,
dtype, **kwargs)
def hierarchical_kmeans(self, pts, branch_size, num_branches,
max_iterations = None,
dtype = None, **kwargs):
max_iterations=None,
dtype=None, **kwargs):
"""
Clusters the data by using multiple runs of kmeans to
recursively partition the dataset. The number of resulting
......@@ -339,8 +353,8 @@ class FLANN:
# First verify the paremeters are sensible.
if not pts.dtype.type in allowed_types:
raise FLANNException("Cannot handle type: %s"%pts.dtype)
if pts.dtype.type not in allowed_types:
raise FLANNException('Cannot handle type: %s' % pts.dtype)
if int(branch_size) != branch_size or branch_size < 2:
raise FLANNException('branch_size must be an integer >= 2.')
......@@ -352,51 +366,46 @@ class FLANN:
num_branches = int(num_branches)
if max_iterations == None:
if max_iterations is None:
max_iterations = -1
else:
max_iterations = int(max_iterations)
# init the arrays and starting values
pts = ensure_2d_array(pts,default_flags)
pts = ensure_2d_array(pts, default_flags)
npts, dim = pts.shape
num_clusters = (branch_size-1)*num_branches+1;
num_clusters = (branch_size - 1) * num_branches + 1
if pts.dtype.type == float64:
result = empty( (num_clusters, dim), dtype=float64)
if pts.dtype.type == np.float64:
result = np.empty((num_clusters, dim), dtype=np.float64)
else:
result = empty( (num_clusters, dim), dtype=float32)
result = np.empty((num_clusters, dim), dtype=np.float32)
# set all the parameters appropriately
self.__ensureRandomSeed(kwargs)
params = {"iterations" : max_iterations,
"algorithm" : 'kmeans',
"branching" : branch_size,
"random_seed" : kwargs['random_seed']}
params = {'iterations': max_iterations,
'algorithm': 'kmeans',
'branching': branch_size,
'random_seed': kwargs['random_seed']}
self.__flann_parameters.update(params)
numclusters = flann.compute_cluster_centers[pts.dtype.type](pts, npts, dim,
num_clusters, result,
numclusters = flann.compute_cluster_centers[pts.dtype.type](
pts, npts, dim, num_clusters, result,
pointer(self.__flann_parameters))
if numclusters <= 0:
raise FLANNException('Error occured during clustering procedure.')
if dtype == None:
if dtype is None:
return result
else:
return dtype(result)
##########################################################################################
##########################################################################
# internal bookkeeping functions
def __ensureRandomSeed(self, kwargs):
if not 'random_seed' in kwargs:
kwargs['random_seed'] = self.__rn_gen.randint(2**30)
if 'random_seed' not in kwargs:
kwargs['random_seed'] = self.__rn_gen.randint(2 ** 30)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册