From 900b3e97de99c7a2fb493f577d554032fd6331d4 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 3 Apr 2018 13:05:49 +0800 Subject: [PATCH] send mpi and rpc framework --- paddle/fluid/operators/detail/mpi_utils.cpp | 114 ++++++++++---------- paddle/fluid/operators/detail/mpi_utils.h | 100 ++++++++++------- 2 files changed, 120 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/operators/detail/mpi_utils.cpp b/paddle/fluid/operators/detail/mpi_utils.cpp index 370294fe213..d3191c15631 100644 --- a/paddle/fluid/operators/detail/mpi_utils.cpp +++ b/paddle/fluid/operators/detail/mpi_utils.cpp @@ -12,81 +12,79 @@ #define mpi_tag = 2008 namespace paddle { -namespace operators { -namespace detail { -MPIUtils::MPIUtils(const std::string& worker_name) { - InitMPI(); + namespace operators { + namespace detail { + MPIUtils::MPIUtils(const std::string &worker_name) { + InitMPI(); - int rank = 0, size = 1; - char my_name[max_work_group_size]; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - snprintf(my_name, max_worker_name_length, worker_name.c_str()); + int rank = 0, size = 1; + char my_name[max_work_group_size]; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + snprintf(my_name, max_worker_name_length, worker_name.c_str()); - std::vector worker_names(size * max_worker_name_length); - MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0], - max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD); - for (int i = 0; i < number_of_procs; i++) { - name_to_id_[std::string(&worker_names[i * 128])] = i; - } -} + std::vector worker_names(size * max_worker_name_length); + MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0], + max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD); + for (int i = 0; i < number_of_procs; i++) { + name_to_id_[std::string(&worker_names[i * 128])] = i; + } + } -void MPIUtils::InitMPI() { - int flag = 0; - MPI_CHECK(MPI_Initialized(&flag)); + void MPIUtils::InitMPI() { + int flag = 0; + MPI_CHECK(MPI_Initialized(&flag)); - if (!flag) { - int rank = 0, size = 1, len = -1; - char host_name[max_worker_name_length]; + if (!flag) { + int rank = 0, size = 1, len = -1; + char host_name[max_worker_name_length]; - MPI_Init(0, 0); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - MPI_Get_processor_name(host_name, &len) - } -}; + MPI_Init(0, 0); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Get_processor_name(host_name, &len); + } + }; -MPIIsend::MPIIsend(int dst, const char* req) { - done1 = 0; - done2 = 0; - length = strlen(req); - req = req; -} -MPIIsend::Send() { - MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD, - &msg1_); - MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE) -} + MPISend::MPISend(const Meta &meta) { + done1_ = 1; + done2_ = 0; + this->meta = meta; + } - bool MPIIsend::IsFinished() { - MPI_Status status; - if (!done1_) MPI_Test(&msg1_, &done1_, &status); - return done1; - } + MPISend::Send() { + MPI_Send(&meta.request, meta.count, meta.datatype, meta.dst, meta.tag, + MPI_COMM_WORLD); + done2_ = 1; + } -MPIIsend::~MPIIsend(){ - MPI_Wait(&msg1_, MPI_STATUS_IGNORE); - MPI_Free_mem(req); -} + bool MPISend::IsReady() { + return true; + } -MPIIrecv::MPIIrecv(){ + bool MPISend::IsFinished() { return done1_ && done2_; } -} + MPISend::~MPISend() { MPI_Free_mem(meta); } -MPIIrecv::Recv(){ -} + MPIRecv::MPIRecv(const Meta &meta) { + this->meta = meta; + } -MPIIrecv::IsFinished(){ + MPIRecv::Recv() {} -} + bool MPIRecv::IsReady() { + return true; + } -MPIIrecv::~MPIIrecv(){ + MPIRecv::IsFinished() {} -} + MPIRecv::~MPIRecv() { + MPI_Free_mem(meta); + } -} // namespace detail + } // namespace detail -} // namespace operators + } // namespace operators } // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/mpi_utils.h b/paddle/fluid/operators/detail/mpi_utils.h index a754439c268..05801020bf3 100644 --- a/paddle/fluid/operators/detail/mpi_utils.h +++ b/paddle/fluid/operators/detail/mpi_utils.h @@ -10,46 +10,74 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include #include namespace paddle { -namespace operators { -namespace detail { -class MPIUtils { - public: - MPIUtils(const std::string& worker_name); - const int GetRankID(const std::string& task_id); - - private: - void InitMPI(); - std::map name_id_map; -}; - -class MPIIsend { - public: - MPIIsend(int dst, const char* buf); - bool IsFinished(); - void Send(); - ~MPIIsend(); - - private: - int done1; - int length; - char* req; - MPI_Request msg1_; -}; - -class MPIIrecv { - public: -MPIIrecv(); -bool IsFinished(); - void Recv(); - ~MPIIrecv(); -}; - -} // namespace detail -} // namespace operators + namespace operators { + namespace detail { + class MPIUtils { + public: + MPIUtils(const std::string &worker_name); + + const int GetRankID(const std::string &task_id); + + private: + void InitMPI(); + + std::map name_id_map; + }; + + class Meta { + public: + int src; + int dst; + MPI_Datatype datatype; + char *request; + int count; + int tag; + int device; + }; + + class MPISend { + public: + MPISend(const Meta &meta); + + bool IsFinished(); + + bool IsReady(); + + void Send(); + + ~MPISend(); + + private: + int done1_; + int done2_; + Meta *meta; + }; + + class MPIRecv { + public: + MPIRecv(const Meta &meta); + + bool IsReady(); + + bool IsFinished(); + + void Recv(); + + ~MPIRecv(); + + private: + int done1_; + int done2_; + Meta *meta; + }; + + } // namespace detail + } // namespace operators } // namespace paddle -- GitLab