diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f2a415583eea209c5254ea7a93367664ca5e5589..a25b174d08beb1453be5d3770ebc776b49445b80 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -24,9 +24,18 @@ if (HDF5_FOUND) if (USE_MPI AND HDF5_IS_PARALLEL) + find_package(Boost COMPONENTS mpi serialization system thread REQUIRED) + include_directories(${Boost_INCLUDE_DIRS}) add_definitions("-DHAVE_MPI") + add_executable(flann_example_mpi flann_example_mpi.cpp) - target_link_libraries(flann_example_mpi flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} boost_mpi) + target_link_libraries(flann_example_mpi flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES}) + + add_executable(flann_mpi_server flann_mpi_server.cpp) + target_link_libraries(flann_mpi_server flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES}) + + add_executable(flann_mpi_client flann_mpi_client.cpp) + target_link_libraries(flann_mpi_client flann_cpp ${HDF5_LIBRARIES} ${MPI_LIBRARIES} ${Boost_LIBRARIES}) add_dependencies(examples flann_example_mpi) install (TARGETS flann_example_mpi DESTINATION bin) diff --git a/examples/flann_example_mpi.cpp b/examples/flann_example_mpi.cpp index 0094a697b6124c653d27ba13a0de43659a06b01a..88091c1bb7e362f5ae691995bcfa2f0c36887d05 100644 --- a/examples/flann_example_mpi.cpp +++ b/examples/flann_example_mpi.cpp @@ -1,9 +1,8 @@ #include -#include -#include #include +#include #define IF_RANK0 if (world.rank()==0) @@ -44,53 +43,102 @@ float compute_precision(const flann::Matrix& match, const flann::Matrix +void serialize(Archive & ar, flann::Matrix & matrix, const unsigned int version) +{ + ar & matrix.rows & matrix.cols & matrix.stride; + if (Archive::is_loading::value) { + matrix.data = new T[matrix.rows*matrix.cols]; + } + ar & boost::serialization::make_array(matrix.data, matrix.rows*matrix.cols); +} + +} +} + + + +void search(flann::mpi::Index >* index) { - boost::mpi::environment env(argc, argv); boost::mpi::communicator world; int nn = 1; - //flann::Matrix dataset; flann::Matrix query; - flann::Matrix match; - flann::Matrix gt_dists; + flann::Matrix match; + // flann::Matrix gt_dists; + IF_RANK0 { + flann::load_from_file(query, "sift100K.h5","query"); + flann::load_from_file(match, "sift100K.h5","match"); + // flann::load_from_file(gt_dists, "sift100K.h5","dists"); + } - IF_RANK0 start_timer("Loading data...\n"); - //flann::load_from_file(dataset, "sift100K.h5","dataset"); - flann::load_from_file(query, "sift100K.h5","query"); - flann::load_from_file(match, "sift100K.h5","match"); - flann::load_from_file(gt_dists, "sift100K.h5","dists"); - flann::Matrix indices(new int[query.rows*nn], query.rows, nn); - flann::Matrix dists(new float[query.rows*nn], query.rows, nn); + boost::mpi::broadcast(world, query, 0); + boost::mpi::broadcast(world, match, 0); + + flann::Matrix indices(new int[query.rows*nn], query.rows, nn); + flann::Matrix dists(new float[query.rows*nn], query.rows, nn); + + IF_RANK0 { + indices = flann::Matrix(new int[query.rows*nn], query.rows, nn); + dists = flann::Matrix(new float[query.rows*nn], query.rows, nn); + } + + // do a knn search, using 128 checks0 + IF_RANK0 start_timer("Performing search...\n"); + index->knnSearch(query, indices, dists, nn, flann::SearchParams(128)); + IF_RANK0 { + printf("Search done (%g seconds)\n", stop_timer()); + printf("Indices size: (%d,%d)\n", (int)indices.rows, (int)indices.cols); + printf("Checking results\n"); + float precision = compute_precision(match, indices); + printf("Precision is: %g\n", precision); + } + delete[] query.data; + delete[] match.data; + + IF_RANK0 { + delete[] indices.data; + delete[] dists.data; + } + +} + + +int main(int argc, char** argv) +{ + boost::mpi::environment env(argc, argv); + boost::mpi::communicator world; + + //flann::Matrix dataset; + + IF_RANK0 start_timer("Loading data...\n"); // construct an randomized kd-tree index using 4 kd-trees flann::mpi::Index > index("sift100K.h5", "dataset", flann::KDTreeIndexParams(4)); + //flann::load_from_file(dataset, "sift100K.h5","dataset"); //flann::Index > index( dataset, flann::KDTreeIndexParams(4)); + world.barrier(); IF_RANK0 printf("Loading data done (%g seconds)\n", stop_timer()); - IF_RANK0 printf("Index size: (%d,%d)\n", index.size(), index.veclen()); start_timer("Building index...\n"); - index.buildIndex(); + index.buildIndex(); printf("Building index done (%g seconds)\n", stop_timer()); - world.barrier(); - // do a knn search, using 128 checks - IF_RANK0 start_timer("Performing search...\n"); - index.knnSearch(query, indices, dists, nn, flann::SearchParams(128)); - IF_RANK0 printf("Search done (%g seconds)\n", stop_timer()); - IF_RANK0 { - printf("Indices size: (%d,%d)\n", (int)indices.rows, (int)indices.cols); - printf("Checking results\n"); - float precision = compute_precision(match, indices); - printf("Precision is: %g\n", precision); - } - delete[] query.data; - delete[] indices.data; - delete[] dists.data; - delete[] match.data; + printf("Searching...\n"); + + + boost::thread t(boost::bind(search, &index)); + t.join(); + boost::thread t2(boost::bind(search, &index)); + + + for(;;){}; return 0; } diff --git a/examples/flann_mpi_client.cpp b/examples/flann_mpi_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b8196994021ef79eb55a1f5389b354c92637004 --- /dev/null +++ b/examples/flann_mpi_client.cpp @@ -0,0 +1,132 @@ +#include +#include + +#include +#include +#include +#include +#include +#include "queries.h" + +#define IF_RANK0 if (world.rank()==0) + +clock_t start_time_; + +void start_timer(const std::string& message = "") +{ + if (!message.empty()) { + printf("%s", message.c_str()); + fflush(stdout); + } + start_time_ = clock(); +} + +double stop_timer() +{ + return double(clock()-start_time_)/CLOCKS_PER_SEC; +} + +float compute_precision(const flann::Matrix& match, const flann::Matrix& indices) +{ + int count = 0; + + assert(match.rows == indices.rows); + size_t nn = std::min(match.cols, indices.cols); + + for(size_t i=0; i +class ClientIndex +{ +public: + ClientIndex(const std::string& host, const std::string& service) + { + tcp::resolver resolver(io_service_); + tcp::resolver::query query(tcp::v4(), host, service); + iterator_ = resolver.resolve(query); + } + + + void knnSearch(const flann::Matrix& queries, flann::Matrix& indices, flann::Matrix& dists, int knn, const SearchParams& params) + { + tcp::socket sock(io_service_); + sock.connect(*iterator_); + + Request req; + req.nn = knn; + req.queries = queries; + // send request + write_object(sock,req); + + Response resp; + // read response + read_object(sock, resp); + + for (size_t i=0;i query; + flann::Matrix match; + + flann::load_from_file(query, "sift100K.h5","query"); + flann::load_from_file(match, "sift100K.h5","match"); + // flann::load_from_file(gt_dists, "sift100K.h5","dists"); + + flann::ClientIndex index("localhost","9999"); + + int nn = 1; + flann::Matrix indices(new int[query.rows*nn], query.rows, nn); + flann::Matrix dists(new float[query.rows*nn], query.rows, nn); + + start_timer("Performing search...\n"); + index.knnSearch(query, indices, dists, nn, flann::SearchParams(128)); + printf("Search done (%g seconds)\n", stop_timer()); + + printf("Checking results\n"); + float precision = compute_precision(match, indices); + printf("Precision is: %g\n", precision); + + } + catch (std::exception& e) { + std::cerr << "Exception: " << e.what() << "\n"; + } + + return 0; +} + diff --git a/examples/flann_mpi_server.cpp b/examples/flann_mpi_server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b2ab2a3ff8e19cf2f6cc47bfe0e5af9706825db --- /dev/null +++ b/examples/flann_mpi_server.cpp @@ -0,0 +1,138 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "queries.h" + +namespace flann { + +template +class MPIServer +{ + + typedef typename Distance::ElementType ElementType; + typedef typename Distance::ResultType DistanceType; + typedef boost::shared_ptr socket_ptr; + typedef flann::mpi::Index FlannIndex; + + void session(socket_ptr sock) + { + boost::mpi::communicator world; + try { + Request req; + if (world.rank()==0) { + read_object(*sock,req); + std::cout << "Received query\n"; + } + // broadcast request to all MPI processes + boost::mpi::broadcast(world, req, 0); + + Response resp; + if (world.rank()==0) { + int rows = req.queries.rows; + int cols = req.nn; + resp.indices = flann::Matrix(new int[rows*cols], rows, cols); + resp.dists = flann::Matrix(new DistanceType[rows*cols], rows, cols); + } + + std::cout << "Searching in process " << world.rank() << "\n"; + index_->knnSearch(req.queries, resp.indices, resp.dists, req.nn, flann::SearchParams(128)); + + if (world.rank()==0) { + std::cout << "Sending result\n"; + write_object(*sock,resp); + } + + delete[] req.queries.data; + if (world.rank()==0) { + delete[] resp.indices.data; + delete[] resp.dists.data; + } + + } + catch (std::exception& e) { + std::cerr << "Exception in thread: " << e.what() << "\n"; + } + } + + + +public: + MPIServer(const std::string& filename, const std::string& dataset, short port) : + port_(port) + { + boost::mpi::communicator world; + if (world.rank()==0) { + std::cout << "Reading dataset and building index..."; + std::flush(std::cout); + } + index_ = new FlannIndex(filename, dataset, flann::KDTreeIndexParams(4)); + index_->buildIndex(); + world.barrier(); // wait for data to be loaded and indexes to be created + if (world.rank()==0) { + std::cout << "done.\n"; + } + } + + + void run() + { + boost::mpi::communicator world; + boost::shared_ptr io_service; + boost::shared_ptr acceptor; + + if (world.rank()==0) { + io_service.reset(new boost::asio::io_service()); + acceptor.reset(new tcp::acceptor(*io_service, tcp::endpoint(tcp::v4(), port_))); + std::cout << "Start listening for queries...\n"; + } + for (;;) { + socket_ptr sock; + if (world.rank()==0) { + sock.reset(new tcp::socket(*io_service)); + acceptor->accept(*sock); + std::cout << "Accepted connection\n"; + } + world.barrier(); // everybody waits here for a connection + boost::thread t(boost::bind(&MPIServer::session, this, sock)); + t.join(); + } + + } + +private: + FlannIndex* index_; + short port_; +}; + + +} + + +int main(int argc, char* argv[]) +{ + boost::mpi::environment env(argc, argv); + + try { + if (argc != 4) { + std::cout << "Usage: " << argv[0] << " \n"; + return 1; + } + flann::MPIServer > server(argv[1], argv[2], std::atoi(argv[3])); + + server.run(); + } + catch (std::exception& e) { + std::cerr << "Exception: " << e.what() << "\n"; + } + + return 0; +} + diff --git a/examples/queries.h b/examples/queries.h new file mode 100644 index 0000000000000000000000000000000000000000..16258ac6b074147d5a8b566b778f2823c0dd1b62 --- /dev/null +++ b/examples/queries.h @@ -0,0 +1,118 @@ +/*********************************************************************** + * Software License Agreement (BSD License) + * + * Copyright 2008-2010 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. + * Copyright 2008-2010 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *************************************************************************/ + + +#ifndef QUERIES_H_ +#define QUERIES_H_ + +#include +#include +#include + + +namespace boost { +namespace serialization { + +template +void serialize(Archive & ar, flann::Matrix & matrix, const unsigned int version) +{ + ar & matrix.rows & matrix.cols & matrix.stride; + if (Archive::is_loading::value) { + matrix.data = new T[matrix.rows*matrix.cols]; + } + ar & boost::serialization::make_array(matrix.data, matrix.rows*matrix.cols); +} + +} +} + +namespace flann +{ + +template +struct Request +{ + flann::Matrix queries; + int nn; + + template + void serialize(Archive& ar, const unsigned int version) + { + ar & queries & nn; + } +}; + +template +struct Response +{ + flann::Matrix indices; + flann::Matrix dists; + + template + void serialize(Archive& ar, const unsigned int version) + { + ar & indices & dists; + } +}; + + +using boost::asio::ip::tcp; + +template +void read_object(tcp::socket& sock, T& val) +{ + uint32_t size; + boost::asio::read(sock, boost::asio::buffer(&size, sizeof(size))); + size = ntohl(size); + + boost::asio::streambuf archive_stream; + boost::asio::read(sock, archive_stream, boost::asio::transfer_at_least(size)); + + boost::archive::binary_iarchive archive(archive_stream); + archive >> val; +} + +template +void write_object(tcp::socket& sock, const T& val) +{ + boost::asio::streambuf archive_stream; + boost::archive::binary_oarchive archive(archive_stream); + archive << val; + + uint32_t size = archive_stream.size(); + size = htonl(size); + boost::asio::write(sock, boost::asio::buffer(&size, sizeof(size))); + boost::asio::write(sock, archive_stream); + +} + +} + + + +#endif /* QUERIES_H_ */