提交 decfc7a5 编写于 作者: M Marius Muja

MPI client/server

上级 72de077f
......@@ -24,9 +24,18 @@ if (HDF5_FOUND)
if (USE_MPI AND HDF5_IS_PARALLEL)
find_package(Boost COMPONENTS mpi serialization system thread REQUIRED)
include_directories(${Boost_INCLUDE_DIRS})
add_definitions("-DHAVE_MPI")
add_executable(flann_example_mpi flann_example_mpi.cpp)
target_link_libraries(flann_example_mpi flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} boost_mpi)
target_link_libraries(flann_example_mpi flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES})
add_executable(flann_mpi_server flann_mpi_server.cpp)
target_link_libraries(flann_mpi_server flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES})
add_executable(flann_mpi_client flann_mpi_client.cpp)
target_link_libraries(flann_mpi_client flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES})
add_dependencies(examples flann_example_mpi)
install (TARGETS flann_example_mpi DESTINATION bin)
......
#include <flann/flann_mpi.hpp>
#include <flann/flann.hpp>
#include <flann/io/hdf5.h>
#include <stdio.h>
#include <boost/thread/thread.hpp>
#define IF_RANK0 if (world.rank()==0)
......@@ -44,53 +43,102 @@ float compute_precision(const flann::Matrix<int>& match, const flann::Matrix<int
}
int main(int argc, char** argv)
namespace boost {
namespace serialization {
template<class Archive, class T>
void serialize(Archive & ar, flann::Matrix<T> & matrix, const unsigned int version)
{
ar & matrix.rows & matrix.cols & matrix.stride;
if (Archive::is_loading::value) {
matrix.data = new T[matrix.rows*matrix.cols];
}
ar & boost::serialization::make_array(matrix.data, matrix.rows*matrix.cols);
}
}
}
void search(flann::mpi::Index<flann::L2<float> >* index)
{
boost::mpi::environment env(argc, argv);
boost::mpi::communicator world;
int nn = 1;
//flann::Matrix<float> dataset;
flann::Matrix<float> query;
flann::Matrix<int> match;
flann::Matrix<float> gt_dists;
flann::Matrix<int> match;
// flann::Matrix<float> gt_dists;
IF_RANK0 {
flann::load_from_file(query, "sift100K.h5","query");
flann::load_from_file(match, "sift100K.h5","match");
// flann::load_from_file(gt_dists, "sift100K.h5","dists");
}
IF_RANK0 start_timer("Loading data...\n");
//flann::load_from_file(dataset, "sift100K.h5","dataset");
flann::load_from_file(query, "sift100K.h5","query");
flann::load_from_file(match, "sift100K.h5","match");
flann::load_from_file(gt_dists, "sift100K.h5","dists");
flann::Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
flann::Matrix<float> dists(new float[query.rows*nn], query.rows, nn);
boost::mpi::broadcast(world, query, 0);
boost::mpi::broadcast(world, match, 0);
flann::Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
flann::Matrix<float> dists(new float[query.rows*nn], query.rows, nn);
IF_RANK0 {
indices = flann::Matrix<int>(new int[query.rows*nn], query.rows, nn);
dists = flann::Matrix<float>(new float[query.rows*nn], query.rows, nn);
}
// do a knn search, using 128 checks0
IF_RANK0 start_timer("Performing search...\n");
index->knnSearch(query, indices, dists, nn, flann::SearchParams(128));
IF_RANK0 {
printf("Search done (%g seconds)\n", stop_timer());
printf("Indices size: (%d,%d)\n", (int)indices.rows, (int)indices.cols);
printf("Checking results\n");
float precision = compute_precision(match, indices);
printf("Precision is: %g\n", precision);
}
delete[] query.data;
delete[] match.data;
IF_RANK0 {
delete[] indices.data;
delete[] dists.data;
}
}
int main(int argc, char** argv)
{
boost::mpi::environment env(argc, argv);
boost::mpi::communicator world;
//flann::Matrix<float> dataset;
IF_RANK0 start_timer("Loading data...\n");
// construct an randomized kd-tree index using 4 kd-trees
flann::mpi::Index<flann::L2<float> > index("sift100K.h5", "dataset", flann::KDTreeIndexParams(4));
//flann::load_from_file(dataset, "sift100K.h5","dataset");
//flann::Index<flann::L2<float> > index( dataset, flann::KDTreeIndexParams(4));
world.barrier();
IF_RANK0 printf("Loading data done (%g seconds)\n", stop_timer());
IF_RANK0 printf("Index size: (%d,%d)\n", index.size(), index.veclen());
start_timer("Building index...\n");
index.buildIndex();
index.buildIndex();
printf("Building index done (%g seconds)\n", stop_timer());
world.barrier();
// do a knn search, using 128 checks
IF_RANK0 start_timer("Performing search...\n");
index.knnSearch(query, indices, dists, nn, flann::SearchParams(128));
IF_RANK0 printf("Search done (%g seconds)\n", stop_timer());
IF_RANK0 {
printf("Indices size: (%d,%d)\n", (int)indices.rows, (int)indices.cols);
printf("Checking results\n");
float precision = compute_precision(match, indices);
printf("Precision is: %g\n", precision);
}
delete[] query.data;
delete[] indices.data;
delete[] dists.data;
delete[] match.data;
printf("Searching...\n");
boost::thread t(boost::bind(search, &index));
t.join();
boost::thread t2(boost::bind(search, &index));
for(;;){};
return 0;
}
#include <stdio.h>
#include <time.h>
#include <cstdlib>
#include <iostream>
#include <flann/util/params.h>
#include <flann/io/hdf5.h>
#include <boost/asio.hpp>
#include "queries.h"
#define IF_RANK0 if (world.rank()==0)
clock_t start_time_;
void start_timer(const std::string& message = "")
{
if (!message.empty()) {
printf("%s", message.c_str());
fflush(stdout);
}
start_time_ = clock();
}
double stop_timer()
{
return double(clock()-start_time_)/CLOCKS_PER_SEC;
}
float compute_precision(const flann::Matrix<int>& match, const flann::Matrix<int>& indices)
{
int count = 0;
assert(match.rows == indices.rows);
size_t nn = std::min(match.cols, indices.cols);
for(size_t i=0; i<match.rows; ++i) {
for (size_t j=0;j<nn;++j) {
for (size_t k=0;k<nn;++k) {
if (match[i][j]==indices[i][k]) {
count ++;
}
}
}
}
return float(count)/(nn*match.rows);
}
namespace flann {
template<typename ElementType, typename DistanceType>
class ClientIndex
{
public:
ClientIndex(const std::string& host, const std::string& service)
{
tcp::resolver resolver(io_service_);
tcp::resolver::query query(tcp::v4(), host, service);
iterator_ = resolver.resolve(query);
}
void knnSearch(const flann::Matrix<ElementType>& queries, flann::Matrix<int>& indices, flann::Matrix<DistanceType>& dists, int knn, const SearchParams& params)
{
tcp::socket sock(io_service_);
sock.connect(*iterator_);
Request<ElementType> req;
req.nn = knn;
req.queries = queries;
// send request
write_object(sock,req);
Response<DistanceType> resp;
// read response
read_object(sock, resp);
for (size_t i=0;i<indices.rows;++i) {
for (size_t j=0;j<indices.cols;++j) {
indices[i][j] = resp.indices[i][j];
dists[i][j] = resp.dists[i][j];
}
}
}
private:
boost::asio::io_service io_service_;
tcp::resolver::iterator iterator_;
};
}
using boost::asio::ip::tcp;
int main(int argc, char* argv[])
{
try {
flann::Matrix<float> query;
flann::Matrix<int> match;
flann::load_from_file(query, "sift100K.h5","query");
flann::load_from_file(match, "sift100K.h5","match");
// flann::load_from_file(gt_dists, "sift100K.h5","dists");
flann::ClientIndex<float, float> index("localhost","9999");
int nn = 1;
flann::Matrix<int> indices(new int[query.rows*nn], query.rows, nn);
flann::Matrix<float> dists(new float[query.rows*nn], query.rows, nn);
start_timer("Performing search...\n");
index.knnSearch(query, indices, dists, nn, flann::SearchParams(128));
printf("Search done (%g seconds)\n", stop_timer());
printf("Checking results\n");
float precision = compute_precision(match, indices);
printf("Precision is: %g\n", precision);
}
catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << "\n";
}
return 0;
}
#include <flann/flann_mpi.hpp>
#include <stdio.h>
#include <time.h>
#include <cstdlib>
#include <iostream>
#include <boost/bind.hpp>
#include <boost/smart_ptr.hpp>
#include <boost/asio.hpp>
#include <boost/thread/thread.hpp>
#include "queries.h"
namespace flann {
template<typename Distance>
class MPIServer
{
typedef typename Distance::ElementType ElementType;
typedef typename Distance::ResultType DistanceType;
typedef boost::shared_ptr<tcp::socket> socket_ptr;
typedef flann::mpi::Index<Distance> FlannIndex;
void session(socket_ptr sock)
{
boost::mpi::communicator world;
try {
Request<ElementType> req;
if (world.rank()==0) {
read_object(*sock,req);
std::cout << "Received query\n";
}
// broadcast request to all MPI processes
boost::mpi::broadcast(world, req, 0);
Response<DistanceType> resp;
if (world.rank()==0) {
int rows = req.queries.rows;
int cols = req.nn;
resp.indices = flann::Matrix<int>(new int[rows*cols], rows, cols);
resp.dists = flann::Matrix<DistanceType>(new DistanceType[rows*cols], rows, cols);
}
std::cout << "Searching in process " << world.rank() << "\n";
index_->knnSearch(req.queries, resp.indices, resp.dists, req.nn, flann::SearchParams(128));
if (world.rank()==0) {
std::cout << "Sending result\n";
write_object(*sock,resp);
}
delete[] req.queries.data;
if (world.rank()==0) {
delete[] resp.indices.data;
delete[] resp.dists.data;
}
}
catch (std::exception& e) {
std::cerr << "Exception in thread: " << e.what() << "\n";
}
}
public:
MPIServer(const std::string& filename, const std::string& dataset, short port) :
port_(port)
{
boost::mpi::communicator world;
if (world.rank()==0) {
std::cout << "Reading dataset and building index...";
std::flush(std::cout);
}
index_ = new FlannIndex(filename, dataset, flann::KDTreeIndexParams(4));
index_->buildIndex();
world.barrier(); // wait for data to be loaded and indexes to be created
if (world.rank()==0) {
std::cout << "done.\n";
}
}
void run()
{
boost::mpi::communicator world;
boost::shared_ptr<boost::asio::io_service> io_service;
boost::shared_ptr<tcp::acceptor> acceptor;
if (world.rank()==0) {
io_service.reset(new boost::asio::io_service());
acceptor.reset(new tcp::acceptor(*io_service, tcp::endpoint(tcp::v4(), port_)));
std::cout << "Start listening for queries...\n";
}
for (;;) {
socket_ptr sock;
if (world.rank()==0) {
sock.reset(new tcp::socket(*io_service));
acceptor->accept(*sock);
std::cout << "Accepted connection\n";
}
world.barrier(); // everybody waits here for a connection
boost::thread t(boost::bind(&MPIServer::session, this, sock));
t.join();
}
}
private:
FlannIndex* index_;
short port_;
};
}
int main(int argc, char* argv[])
{
boost::mpi::environment env(argc, argv);
try {
if (argc != 4) {
std::cout << "Usage: " << argv[0] << " <file> <dataset> <port>\n";
return 1;
}
flann::MPIServer<flann::L2<float> > server(argv[1], argv[2], std::atoi(argv[3]));
server.run();
}
catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << "\n";
}
return 0;
}
/***********************************************************************
* Software License Agreement (BSD License)
*
* Copyright 2008-2010 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
* Copyright 2008-2010 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*************************************************************************/
#ifndef QUERIES_H_
#define QUERIES_H_
#include <flann/util/matrix.h>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
namespace boost {
namespace serialization {
template<class Archive, class T>
void serialize(Archive & ar, flann::Matrix<T> & matrix, const unsigned int version)
{
ar & matrix.rows & matrix.cols & matrix.stride;
if (Archive::is_loading::value) {
matrix.data = new T[matrix.rows*matrix.cols];
}
ar & boost::serialization::make_array(matrix.data, matrix.rows*matrix.cols);
}
}
}
namespace flann
{
template<typename T>
struct Request
{
flann::Matrix<T> queries;
int nn;
template<typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
ar & queries & nn;
}
};
template<typename T>
struct Response
{
flann::Matrix<int> indices;
flann::Matrix<T> dists;
template<typename Archive>
void serialize(Archive& ar, const unsigned int version)
{
ar & indices & dists;
}
};
using boost::asio::ip::tcp;
template <typename T>
void read_object(tcp::socket& sock, T& val)
{
uint32_t size;
boost::asio::read(sock, boost::asio::buffer(&size, sizeof(size)));
size = ntohl(size);
boost::asio::streambuf archive_stream;
boost::asio::read(sock, archive_stream, boost::asio::transfer_at_least(size));
boost::archive::binary_iarchive archive(archive_stream);
archive >> val;
}
template <typename T>
void write_object(tcp::socket& sock, const T& val)
{
boost::asio::streambuf archive_stream;
boost::archive::binary_oarchive archive(archive_stream);
archive << val;
uint32_t size = archive_stream.size();
size = htonl(size);
boost::asio::write(sock, boost::asio::buffer(&size, sizeof(size)));
boost::asio::write(sock, archive_stream);
}
}
#endif /* QUERIES_H_ */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册