flann_mpi_client.cpp 2.8 KB
Newer Older
M
Marius Muja 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
#include <stdio.h>
#include <time.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include <flann/util/params.h>
#include <flann/io/hdf5.h>
#include <boost/asio.hpp>
#include "queries.h"

// Guards the following statement so it only runs when world.rank()==0.
// NOTE(review): `world` (presumably a boost::mpi::communicator) is not declared
// anywhere in this file and the macro is never used here; it looks like a
// leftover from the MPI server code -- consider removing it.
#define IF_RANK0 if (world.rank()==0)

// Tick count captured by the most recent call to start_timer().
clock_t start_time_;

/**
 * Starts (or restarts) the global stopwatch.
 *
 * @param message optional text written to stdout and flushed *before* the
 *                stopwatch is reset, so printing it is not part of the
 *                measured interval.
 *
 * Note: clock() reports an approximation of the processor time used by this
 * process, not elapsed wall-clock time; time spent blocked (for example
 * waiting on a socket) is generally not included.
 */
void start_timer(const std::string& message = "")
{
	const bool announce = !message.empty();
	if (announce) {
		fputs(message.c_str(), stdout);
		fflush(stdout);
	}
	start_time_ = clock();
}

/**
 * @return seconds, as measured by clock(), since the last start_timer().
 */
double stop_timer()
{
	const clock_t ticks = clock() - start_time_;
	return static_cast<double>(ticks) / CLOCKS_PER_SEC;
}

/**
 * Fraction of ground-truth neighbours recovered by a search.
 *
 * Let nn = min(match.cols, indices.cols). For every row i, each of the first
 * nn ids in match[i] is a hit if it occurs among the first nn ids of
 * indices[i]. The number of hits is divided by nn * match.rows.
 *
 * @param match    ground-truth neighbour ids, one row per query
 * @param indices  neighbour ids returned by the search; must have the same
 *                 number of rows as `match` (checked with assert)
 * @return precision in [0, 1]; 0 when there is nothing to compare
 */
float compute_precision(const flann::Matrix<int>& match, const flann::Matrix<int>& indices)
{
	assert(match.rows == indices.rows);
	const size_t nn = std::min(match.cols, indices.cols);

	// Empty input: avoid 0/0 (NaN) and report no precision.
	if (nn == 0 || match.rows == 0) {
		return 0.0f;
	}

	size_t hits = 0;
	for (size_t i = 0; i < match.rows; ++i) {
		for (size_t j = 0; j < nn; ++j) {
			for (size_t k = 0; k < nn; ++k) {
				if (match[i][j] == indices[i][k]) {
					// Credit each ground-truth id at most once, even if the
					// result row contains duplicates, so the result stays <= 1.
					++hits;
					break;
				}
			}
		}
	}

	return static_cast<float>(hits) / static_cast<float>(nn * match.rows);
}


namespace flann {

/**
 * Client-side index that forwards k-nearest-neighbour queries to a remote
 * search server over TCP and copies the answers into caller-owned matrices.
 *
 * ElementType  - scalar type of the query vectors.
 * DistanceType - scalar type of the returned distances.
 *
 * The wire format is whatever Request/Response and write_object()/
 * read_object() (declared outside this file, presumably in queries.h)
 * implement.
 */
template<typename ElementType, typename DistanceType>
class ClientIndex
{
	// Class-local alias so this class does not depend on `tcp` having been
	// brought into scope before this point (the file-level using-declaration
	// only appears after this namespace).
	typedef boost::asio::ip::tcp tcp;

public:
	/**
	 * Resolves host:service (IPv4 only) once, up front.
	 * The non-error_code resolve() overload used here throws
	 * boost::system::system_error on failure.
	 */
	ClientIndex(const std::string& host, const std::string& service)
	{
		tcp::resolver resolver(io_service_);
		tcp::resolver::query query(tcp::v4(), host, service);
		iterator_ = resolver.resolve(query);
	}

	/**
	 * Sends `queries` to the server, asking for `knn` neighbours each, and
	 * copies the reply into `indices` / `dists`.
	 *
	 * A new connection to the first resolved endpoint is opened for every
	 * call and closed when the local socket goes out of scope.
	 *
	 * Only the region the request can have produced is written: at most
	 * queries.rows rows and at most knn columns (further limited by the sizes
	 * of `indices` and `dists`). Cells outside that region are left untouched.
	 *
	 * @param params NOTE(review): currently not sent to the server (e.g. the
	 *               number of checks is ignored); confirm whether Request can
	 *               carry search parameters.
	 */
	void knnSearch(const flann::Matrix<ElementType>& queries, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, int knn, const SearchParams& params)
	{
		tcp::socket sock(io_service_);
		sock.connect(*iterator_);

		Request<ElementType> req;
		req.nn = knn;
		req.queries = queries;
		// send request
		write_object(sock, req);

		Response<DistanceType> resp;
		// read response
		read_object(sock, resp);

		// Presumably the server answers with queries.rows x knn results -- TODO
		// confirm against the server. Copying the caller's full rows x cols
		// could read past the response if the matrices are larger than that.
		const size_t requested_cols = knn > 0 ? static_cast<size_t>(knn) : size_t(0);
		const size_t rows = std::min<size_t>(std::min<size_t>(indices.rows, dists.rows), queries.rows);
		const size_t cols = std::min<size_t>(std::min<size_t>(indices.cols, dists.cols), requested_cols);

		for (size_t i = 0; i < rows; ++i) {
			for (size_t j = 0; j < cols; ++j) {
				indices[i][j] = resp.indices[i][j];
				dists[i][j] = resp.dists[i][j];
			}
		}
	}


private:
	boost::asio::io_service io_service_;
	tcp::resolver::iterator iterator_;
};


}


using boost::asio::ip::tcp;


int main(int argc, char* argv[])
{
	try {

		flann::Matrix<float> query;
		flann::Matrix<int> match;

		flann::load_from_file(query, "sift100K.h5","query");
		flann::load_from_file(match, "sift100K.h5","match");
		//	flann::load_from_file(gt_dists, "sift100K.h5","dists");

		flann::ClientIndex<float, float> index("localhost","9999");

		int nn = 1;
		flann::Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
		flann::Matrix<float> dists(new float[query.rows*nn], query.rows, nn);

		start_timer("Performing search...\n");
		index.knnSearch(query, indices, dists, nn, flann::SearchParams(128));
		printf("Search done (%g seconds)\n", stop_timer());

		printf("Checking results\n");
		float precision = compute_precision(match, indices);
		printf("Precision is: %g\n", precision);

	}
	catch (std::exception& e) {
		std::cerr << "Exception: " << e.what() << "\n";
	}

	return 0;
}