diff --git a/modules/flann/include/opencv2/flann/dist.h b/modules/flann/include/opencv2/flann/dist.h index 80ae2dc9162bcf9ff0df587f4e4968db80839fcd..2afceb88930b46d4763b468e128381a733b1437b 100644 --- a/modules/flann/include/opencv2/flann/dist.h +++ b/modules/flann/include/opencv2/flann/dist.h @@ -812,6 +812,66 @@ struct ZeroIterator }; + +/* + * Depending on processed distances, some of them are already squared (e.g. L2) + * and some are not (e.g.Hamming). In KMeans++ for instance we want to be sure + * we are working on ^2 distances, thus following templates to ensure that. + */ +template +struct squareDistance +{ + typedef typename Distance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist*dist; } +}; + + +template +struct squareDistance, ElementType> +{ + typedef typename L2_Simple::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename L2::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + + +template +struct squareDistance, ElementType> +{ + typedef typename MinkowskiDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename HellingerDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename ChiSquareDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + + +template +typename Distance::ResultType ensureSquareDistance( typename Distance::ResultType dist ) +{ + typedef typename Distance::ElementType ElementType; + + squareDistance dummy; + return dummy( dist ); +} + } #endif //OPENCV_FLANN_DIST_H_ diff --git a/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h b/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h index b511ee9089b6786334e3dcf33030d2a391eaf11c..54cb8e55fcae04199d7ae2342c25637d056d6971 100644 --- a/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h +++ b/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h @@ -210,8 +210,11 @@ private: assert(index >=0 && index < n); centers[0] = dsindices[index]; + // Computing distance^2 will have the advantage of even higher probability further to pick new centers + // far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article) for (int i = 0; i < n; i++) { closestDistSq[i] = distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols); + closestDistSq[i] = ensureSquareDistance( closestDistSq[i] ); currentPot += closestDistSq[i]; } @@ -237,7 +240,10 @@ private: // Compute the new potential double newPot = 0; - for (int i = 0; i < n; i++) newPot += std::min( distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols); + newPot += std::min( ensureSquareDistance(dist), closestDistSq[i] ); + } // Store the best result if ((bestNewPot < 0)||(newPot < bestNewPot)) { @@ -249,7 +255,10 @@ private: // Add the appropriate center centers[centerCount] = dsindices[bestNewIndex]; currentPot = bestNewPot; - for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance(dataset[dsindices[i]], dataset[dsindices[bestNewIndex]], dataset.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance(dataset[dsindices[i]], dataset[dsindices[bestNewIndex]], dataset.cols); + closestDistSq[i] = std::min( ensureSquareDistance(dist), closestDistSq[i] ); + } } centers_length = centerCount; diff --git a/modules/flann/include/opencv2/flann/kmeans_index.h b/modules/flann/include/opencv2/flann/kmeans_index.h index 489ed805658f35b2548fa93c957de9f8816a76c2..8bcb63d72f49dfbc4eaa8f90e7f1f02e977646e4 100644 --- a/modules/flann/include/opencv2/flann/kmeans_index.h +++ b/modules/flann/include/opencv2/flann/kmeans_index.h @@ -211,6 +211,7 @@ public: for (int i = 0; i < n; i++) { closestDistSq[i] = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols); + closestDistSq[i] = ensureSquareDistance( closestDistSq[i] ); currentPot += closestDistSq[i]; } @@ -236,7 +237,10 @@ public: // Compute the new potential double newPot = 0; - for (int i = 0; i < n; i++) newPot += std::min( distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols); + newPot += std::min( ensureSquareDistance(dist), closestDistSq[i] ); + } // Store the best result if ((bestNewPot < 0)||(newPot < bestNewPot)) { @@ -248,7 +252,10 @@ public: // Add the appropriate center centers[centerCount] = indices[bestNewIndex]; currentPot = bestNewPot; - for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance_(dataset_[indices[i]], dataset_[indices[bestNewIndex]], dataset_.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance_(dataset_[indices[i]], dataset_[indices[bestNewIndex]], dataset_.cols); + closestDistSq[i] = std::min( ensureSquareDistance(dist), closestDistSq[i] ); + } } centers_length = centerCount;