提交 c2063f73 编写于 作者: M Marius Muja

Some API breaking change to flann::Matrix

上级 e6336baa
...@@ -27,10 +27,10 @@ int main(int argc, char** argv) ...@@ -27,10 +27,10 @@ int main(int argc, char** argv)
flann::save_to_file(indices,"result.hdf5","result"); flann::save_to_file(indices,"result.hdf5","result");
delete[] dataset.data; delete[] dataset.ptr();
delete[] query.data; delete[] query.ptr();
delete[] indices.data; delete[] indices.ptr();
delete[] dists.data; delete[] dists.ptr();
return 0; return 0;
} }
...@@ -470,9 +470,9 @@ private: ...@@ -470,9 +470,9 @@ private:
} }
} }
delete[] gt_matches_.data; delete[] gt_matches_.ptr();
delete[] testDataset_.data; delete[] testDataset_.ptr();
delete[] sampledDataset_.data; delete[] sampledDataset_.ptr();
return bestParams; return bestParams;
} }
...@@ -544,8 +544,8 @@ private: ...@@ -544,8 +544,8 @@ private:
speedup = linear / searchTime; speedup = linear / searchTime;
delete[] gt_matches.data; delete[] gt_matches.ptr();
delete[] testDataset.data; delete[] testDataset.ptr();
} }
return speedup; return speedup;
......
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
*/ */
~KDTreeSingleIndex() ~KDTreeSingleIndex()
{ {
if (reorder_) delete[] data_.data; if (reorder_) delete[] data_.ptr();
} }
/** /**
...@@ -117,7 +117,6 @@ public: ...@@ -117,7 +117,6 @@ public:
root_node_ = divideTree(0, size_, root_bbox_ ); // construct the tree root_node_ = divideTree(0, size_, root_bbox_ ); // construct the tree
if (reorder_) { if (reorder_) {
delete[] data_.data;
data_ = flann::Matrix<ElementType>(new ElementType[size_*dim_], size_, dim_); data_ = flann::Matrix<ElementType>(new ElementType[size_*dim_], size_, dim_);
for (size_t i=0; i<size_; ++i) { for (size_t i=0; i<size_; ++i) {
for (size_t j=0; j<dim_; ++j) { for (size_t j=0; j<dim_; ++j) {
......
...@@ -815,7 +815,7 @@ private: ...@@ -815,7 +815,7 @@ private:
start=end; start=end;
} }
delete[] dcenters.data; delete[] dcenters.ptr();
delete[] centers; delete[] centers;
delete[] radiuses; delete[] radiuses;
delete[] count; delete[] count;
......
...@@ -105,9 +105,8 @@ public: ...@@ -105,9 +105,8 @@ public:
void findNeighbors(ResultSet<DistanceType>& resultSet, const ElementType* vec, const SearchParams& /*searchParams*/) void findNeighbors(ResultSet<DistanceType>& resultSet, const ElementType* vec, const SearchParams& /*searchParams*/)
{ {
ElementType* data = dataset_.data; for (size_t i = 0; i < dataset_.rows; ++i) {
for (size_t i = 0; i < dataset_.rows; ++i, data += dataset_.cols) { DistanceType dist = distance_(dataset_[i], vec, dataset_.cols);
DistanceType dist = distance_(data, vec, dataset_.cols);
resultSet.addPoint(dist, i); resultSet.addPoint(dist, i);
} }
} }
......
...@@ -155,6 +155,7 @@ enum flann_distance_t ...@@ -155,6 +155,7 @@ enum flann_distance_t
enum flann_datatype_t enum flann_datatype_t
{ {
FLANN_NONE = -1,
FLANN_INT8 = 0, FLANN_INT8 = 0,
FLANN_INT16 = 1, FLANN_INT16 = 1,
FLANN_INT32 = 2, FLANN_INT32 = 2,
......
...@@ -80,7 +80,7 @@ NNIndex<Distance>* load_saved_index(const Matrix<typename Distance::ElementType> ...@@ -80,7 +80,7 @@ NNIndex<Distance>* load_saved_index(const Matrix<typename Distance::ElementType>
return NULL; return NULL;
} }
IndexHeader header = load_header(fin); IndexHeader header = load_header(fin);
if (header.data_type != Datatype<ElementType>::type()) { if (header.data_type != flann_datatype<ElementType>::value) {
throw FLANNException("Datatype of saved index is different than of the one to be created."); throw FLANNException("Datatype of saved index is different than of the one to be created.");
} }
if ((size_t(header.rows) != dataset.rows)||(size_t(header.cols) != dataset.cols)) { if ((size_t(header.rows) != dataset.rows)||(size_t(header.cols) != dataset.cols)) {
......
...@@ -46,6 +46,98 @@ public: ...@@ -46,6 +46,98 @@ public:
FLANNException(const std::string& message) : std::runtime_error(message) { } FLANNException(const std::string& message) : std::runtime_error(message) { }
}; };
inline size_t flann_datatype_size(flann_datatype_t type)
{
switch (type) {
case FLANN_INT8:
return 1;
break;
case FLANN_INT16:
return 2;
break;
case FLANN_INT32:
return 4;
break;
case FLANN_INT64:
return 8;
break;
case FLANN_UINT8:
return 1;
break;
case FLANN_UINT16:
return 2;
break;
case FLANN_UINT32:
return 4;
break;
case FLANN_UINT64:
return 8;
break;
case FLANN_FLOAT32:
return 4;
break;
case FLANN_FLOAT64:
return 8;
break;
default:
return 1;
}
}
template <typename T>
struct flann_datatype
{
static const flann_datatype_t value = FLANN_NONE;
};
template<>
struct flann_datatype<char>
{
static const flann_datatype_t value = FLANN_INT8;
};
template<>
struct flann_datatype<short>
{
static const flann_datatype_t value = FLANN_INT16;
};
template<>
struct flann_datatype<int>
{
static const flann_datatype_t value = FLANN_INT32;
};
template<>
struct flann_datatype<unsigned char>
{
static const flann_datatype_t value = FLANN_UINT8;
};
template<>
struct flann_datatype<unsigned short>
{
static const flann_datatype_t value = FLANN_UINT16;
};
template<>
struct flann_datatype<unsigned int>
{
static const flann_datatype_t value = FLANN_UINT32;
};
template<>
struct flann_datatype<float>
{
static const flann_datatype_t value = FLANN_FLOAT32;
};
template<>
struct flann_datatype<double>
{
static const flann_datatype_t value = FLANN_FLOAT64;
};
} }
......
...@@ -115,7 +115,7 @@ void save_to_file(const flann::Matrix<T>& dataset, const std::string& filename, ...@@ -115,7 +115,7 @@ void save_to_file(const flann::Matrix<T>& dataset, const std::string& filename,
} }
CHECK_ERROR(dataset_id,"Error creating or opening dataset in file."); CHECK_ERROR(dataset_id,"Error creating or opening dataset in file.");
status = H5Dwrite(dataset_id, get_hdf5_type<T>(), memspace_id, space_id, H5P_DEFAULT, dataset.data ); status = H5Dwrite(dataset_id, get_hdf5_type<T>(), memspace_id, space_id, H5P_DEFAULT, dataset.ptr() );
CHECK_ERROR(status, "Error writing to dataset"); CHECK_ERROR(status, "Error writing to dataset");
H5Sclose(memspace_id); H5Sclose(memspace_id);
......
...@@ -52,22 +52,20 @@ struct SearchResults ...@@ -52,22 +52,20 @@ struct SearchResults
ar& indices.rows; ar& indices.rows;
ar& indices.cols; ar& indices.cols;
if (Archive::is_loading::value) { if (Archive::is_loading::value) {
indices.stride = indices.cols; indices = Matrix<int>(new int[indices.rows*indices.cols], indices.rows, indices.cols);
indices.data = new int[indices.rows*indices.cols];
} }
ar& boost::serialization::make_array(indices.data, indices.rows*indices.cols); ar& boost::serialization::make_array(indices.ptr(), indices.rows*indices.cols);
if (Archive::is_saving::value) { if (Archive::is_saving::value) {
delete[] indices.data; delete[] indices.ptr();
} }
ar& dists.rows; ar& dists.rows;
ar& dists.cols; ar& dists.cols;
if (Archive::is_loading::value) { if (Archive::is_loading::value) {
dists.stride = dists.cols; dists = Matrix<DistanceType>(new DistanceType[dists.rows*dists.cols], dists.rows, dists.cols);
dists.data = new DistanceType[dists.rows*dists.cols];
} }
ar& boost::serialization::make_array(dists.data, dists.rows*dists.cols); ar& boost::serialization::make_array(dists.ptr(), dists.rows*dists.cols);
if (Archive::is_saving::value) { if (Archive::is_saving::value) {
delete[] dists.data; delete[] dists.ptr();
} }
} }
}; };
...@@ -101,10 +99,10 @@ struct ResultsMerger ...@@ -101,10 +99,10 @@ struct ResultsMerger
} }
} }
} }
delete[] a.indices.data; delete[] a.indices.ptr();
delete[] a.dists.data; delete[] a.dists.ptr();
delete[] b.indices.data; delete[] b.indices.ptr();
delete[] b.dists.data; delete[] b.dists.ptr();
return results; return results;
} }
}; };
...@@ -191,7 +189,7 @@ template<typename Distance> ...@@ -191,7 +189,7 @@ template<typename Distance>
Index<Distance>::~Index() Index<Distance>::~Index()
{ {
delete flann_index; delete flann_index;
delete[] dataset.data; delete[] dataset.ptr();
} }
template<typename Distance> template<typename Distance>
...@@ -222,8 +220,8 @@ void Index<Distance>::knnSearch(const flann::Matrix<ElementType>& queries, flann ...@@ -222,8 +220,8 @@ void Index<Distance>::knnSearch(const flann::Matrix<ElementType>& queries, flann
dists[i][j] = results.dists[i][j]; dists[i][j] = results.dists[i][j];
} }
} }
delete[] results.indices.data; delete[] results.indices.ptr();
delete[] results.dists.data; delete[] results.dists.ptr();
} }
} }
...@@ -255,8 +253,8 @@ int Index<Distance>::radiusSearch(const flann::Matrix<ElementType>& query, flann ...@@ -255,8 +253,8 @@ int Index<Distance>::radiusSearch(const flann::Matrix<ElementType>& query, flann
dists[i][j] = results.dists[i][j]; dists[i][j] = results.dists[i][j];
} }
} }
delete[] results.indices.data; delete[] results.indices.ptr();
delete[] results.dists.data; delete[] results.dists.ptr();
} }
return 0; return 0;
} }
......
...@@ -37,29 +37,66 @@ ...@@ -37,29 +37,66 @@
namespace flann namespace flann
{ {
typedef unsigned char uchar;
class Matrix_
{
public:
Matrix_() : rows(0), cols(0), stride(0), data(NULL)
{
};
Matrix_(void* data_, size_t rows_, size_t cols_, flann_datatype_t type, size_t stride_ = 0) :
rows(rows_), cols(cols_), stride(stride_)
{
data = static_cast<uchar*>(data_);
if (stride==0) stride = flann_datatype_size(type)*cols;
}
/**
* Operator that returns a (pointer to a) row of the data.
*/
inline void* operator[](size_t index) const
{
return data+index*stride;
}
void* ptr() const
{
return data;
}
size_t rows;
size_t cols;
size_t stride;
flann_datatype_t type;
protected:
uchar* data;
};
/** /**
* Class that implements a simple rectangular matrix stored in a memory buffer and * Class that implements a simple rectangular matrix stored in a memory buffer and
* provides convenient matrix-like access using the [] operators. * provides convenient matrix-like access using the [] operators.
*
* This class has the same memory structure as the un-templated class flann::Matrix_ and
* it's directly convertible from it.
*/ */
template <typename T> template <typename T>
class Matrix class Matrix : public Matrix_
{ {
public: public:
typedef T type; typedef T type;
size_t rows; Matrix() : Matrix_()
size_t cols;
size_t stride;
T* data;
Matrix() : rows(0), cols(0), stride(0), data(NULL)
{ {
} }
Matrix(T* data_, size_t rows_, size_t cols_, size_t stride_ = 0) : Matrix(T* data_, size_t rows_, size_t cols_, size_t stride_ = 0) :
rows(rows_), cols(cols_), stride(stride_), data(data_) Matrix_(data_, rows_, cols_, flann_datatype<T>::value, stride_)
{ {
if (stride==0) stride = cols;
} }
/** /**
...@@ -76,42 +113,18 @@ public: ...@@ -76,42 +113,18 @@ public:
/** /**
* Operator that returns a (pointer to a) row of the data. * Operator that returns a (pointer to a) row of the data.
*/ */
T* operator[](size_t index) const inline T* operator[](size_t index) const
{ {
return data+index*stride; return reinterpret_cast<T*>(static_cast<uchar*>(Matrix_::data)+index*stride);
// return (T*)(Matrix_::operator [](index));
} }
};
T* ptr() const
class UntypedMatrix
{
public:
size_t rows;
size_t cols;
void* data;
flann_datatype_t type;
UntypedMatrix(void* data_, long rows_, long cols_) :
rows(rows_), cols(cols_), data(data_)
{
}
~UntypedMatrix()
{ {
} return reinterpret_cast<T*>(Matrix_::data);
template<typename T>
Matrix<T> as()
{
return Matrix<T>((T*)data, rows, cols);
} }
}; };
} }
#endif //FLANN_DATASET_H_ #endif //FLANN_DATASET_H_
...@@ -45,26 +45,6 @@ ...@@ -45,26 +45,6 @@
namespace flann namespace flann
{ {
template <typename T>
struct Datatype {};
template<>
struct Datatype<char> { static flann_datatype_t type() { return FLANN_INT8; } };
template<>
struct Datatype<short> { static flann_datatype_t type() { return FLANN_INT16; } };
template<>
struct Datatype<int> { static flann_datatype_t type() { return FLANN_INT32; } };
template<>
struct Datatype<unsigned char> { static flann_datatype_t type() { return FLANN_UINT8; } };
template<>
struct Datatype<unsigned short> { static flann_datatype_t type() { return FLANN_UINT16; } };
template<>
struct Datatype<unsigned int> { static flann_datatype_t type() { return FLANN_UINT32; } };
template<>
struct Datatype<float> { static flann_datatype_t type() { return FLANN_FLOAT32; } };
template<>
struct Datatype<double> { static flann_datatype_t type() { return FLANN_FLOAT64; } };
/** /**
* Structure representing the index header. * Structure representing the index header.
*/ */
...@@ -92,7 +72,7 @@ void save_header(FILE* stream, const NNIndex<Distance>& index) ...@@ -92,7 +72,7 @@ void save_header(FILE* stream, const NNIndex<Distance>& index)
strcpy(header.signature, FLANN_SIGNATURE_); strcpy(header.signature, FLANN_SIGNATURE_);
memset(header.version, 0, sizeof(header.version)); memset(header.version, 0, sizeof(header.version));
strcpy(header.version, FLANN_VERSION_); strcpy(header.version, FLANN_VERSION_);
header.data_type = Datatype<typename Distance::ElementType>::type(); header.data_type = flann_datatype<typename Distance::ElementType>::value;
header.index_type = index.getType(); header.index_type = index.getType();
header.rows = index.size(); header.rows = index.size();
header.cols = index.veclen(); header.cols = index.veclen();
...@@ -120,7 +100,6 @@ inline IndexHeader load_header(FILE* stream) ...@@ -120,7 +100,6 @@ inline IndexHeader load_header(FILE* stream)
} }
return header; return header;
} }
...@@ -134,7 +113,7 @@ template<typename T> ...@@ -134,7 +113,7 @@ template<typename T>
void save_value(FILE* stream, const flann::Matrix<T>& value) void save_value(FILE* stream, const flann::Matrix<T>& value)
{ {
fwrite(&value, sizeof(value),1, stream); fwrite(&value, sizeof(value),1, stream);
fwrite(value.data, sizeof(T),value.rows*value.cols, stream); fwrite(value.ptr(), sizeof(T),value.rows*value.cols, stream);
} }
template<typename T> template<typename T>
...@@ -161,8 +140,8 @@ void load_value(FILE* stream, flann::Matrix<T>& value) ...@@ -161,8 +140,8 @@ void load_value(FILE* stream, flann::Matrix<T>& value)
if (read_cnt != 1) { if (read_cnt != 1) {
throw FLANNException("Cannot read from file"); throw FLANNException("Cannot read from file");
} }
value.data = new T[value.rows*value.cols]; value = Matrix<T>(new T[value.rows*value.cols], value.rows, value.cols);
read_cnt = fread(value.data, sizeof(T), value.rows*value.cols, stream); read_cnt = fread(value.ptr(), sizeof(T), value.rows*value.cols, stream);
if (read_cnt != value.rows*value.cols) { if (read_cnt != value.rows*value.cols) {
throw FLANNException("Cannot read from file"); throw FLANNException("Cannot read from file");
} }
......
...@@ -77,11 +77,11 @@ protected: ...@@ -77,11 +77,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
int GetNN() { return nn; } int GetNN() { return nn; }
...@@ -162,12 +162,12 @@ protected: ...@@ -162,12 +162,12 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] dists_single.data; delete[] dists_single.ptr();
delete[] indices_single.data; delete[] indices_single.ptr();
delete[] dists_multi.data; delete[] dists_multi.ptr();
delete[] indices_multi.data; delete[] indices_multi.ptr();
} }
int GetNN() { return nn; } int GetNN() { return nn; }
...@@ -268,12 +268,12 @@ protected: ...@@ -268,12 +268,12 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] dists_single.data; delete[] dists_single.ptr();
delete[] indices_single.data; delete[] indices_single.ptr();
delete[] dists_multi.data; delete[] dists_multi.ptr();
delete[] indices_multi.data; delete[] indices_multi.ptr();
} }
float GetRadius() { return radius; } float GetRadius() { return radius; }
......
...@@ -108,11 +108,11 @@ protected: ...@@ -108,11 +108,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
...@@ -192,11 +192,11 @@ protected: ...@@ -192,11 +192,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
...@@ -275,11 +275,11 @@ protected: ...@@ -275,11 +275,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
...@@ -436,11 +436,11 @@ protected: ...@@ -436,11 +436,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
...@@ -519,11 +519,11 @@ protected: ...@@ -519,11 +519,11 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
...@@ -550,7 +550,7 @@ TEST_F(Flann_3D, KDTreeSingleTest_Padded) ...@@ -550,7 +550,7 @@ TEST_F(Flann_3D, KDTreeSingleTest_Padded)
{ {
flann::Matrix<float> data_padded; flann::Matrix<float> data_padded;
flann::load_from_file(data_padded, "cloud.h5", "dataset_padded"); flann::load_from_file(data_padded, "cloud.h5", "dataset_padded");
flann::Matrix<float> data2(data_padded.data, data_padded.rows, 3, data_padded.cols); flann::Matrix<float> data2(data_padded.ptr(), data_padded.rows, 3, data_padded.cols*sizeof(float));
flann::Index<L2_Simple<float> > index(data2, flann::KDTreeSingleIndexParams(12, false)); flann::Index<L2_Simple<float> > index(data2, flann::KDTreeSingleIndexParams(12, false));
start_timer("Building kd-tree index..."); start_timer("Building kd-tree index...");
...@@ -565,7 +565,7 @@ TEST_F(Flann_3D, KDTreeSingleTest_Padded) ...@@ -565,7 +565,7 @@ TEST_F(Flann_3D, KDTreeSingleTest_Padded)
EXPECT_GE(precision, 0.99); EXPECT_GE(precision, 0.99);
printf("Precision: %g\n", precision); printf("Precision: %g\n", precision);
delete[] data_padded.data; delete[] data_padded.ptr();
} }
TEST_F(Flann_3D, SavedTest) TEST_F(Flann_3D, SavedTest)
...@@ -658,12 +658,12 @@ protected: ...@@ -658,12 +658,12 @@ protected:
void TearDown() void TearDown()
{ {
delete[] data.data; delete[] data.ptr();
delete[] query.data; delete[] query.ptr();
delete[] match.data; delete[] match.ptr();
delete[] gt_dists.data; delete[] gt_dists.ptr();
delete[] dists.data; delete[] dists.ptr();
delete[] indices.data; delete[] indices.ptr();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册