From 94ce020241cbc081a95089e677a79e340f3a4e0a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 28 Mar 2018 15:50:59 +0800 Subject: [PATCH] mpi tools --- paddle/fluid/operators/detail/mpi_client.cpp | 29 ++++++++++++ paddle/fluid/operators/detail/mpi_client.h | 37 ++++++++-------- paddle/fluid/operators/detail/mpi_utils.cpp | 46 +++++++++++++++++++- paddle/fluid/operators/detail/mpi_utils.h | 21 ++++----- 4 files changed, 105 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/operators/detail/mpi_client.cpp diff --git a/paddle/fluid/operators/detail/mpi_client.cpp b/paddle/fluid/operators/detail/mpi_client.cpp new file mode 100644 index 00000000000..6890e437ef1 --- /dev/null +++ b/paddle/fluid/operators/detail/mpi_client.cpp @@ -0,0 +1,29 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mpi_client.h" +#include "mpi_utils.h" + +namespace paddle { +namespace operators { +namespace detail { +bool MPIClient::AsyncSendVariable() { + char* msg = "123456787654"; + int dst = 1; + MPIIsend send = MPIIsend(dst, msg); +} + +bool MPIClient::Wait() {} + +} // namespace detail +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/mpi_client.h b/paddle/fluid/operators/detail/mpi_client.h index 14dcd678a09..a01e5b2d119 100644 --- a/paddle/fluid/operators/detail/mpi_client.h +++ b/paddle/fluid/operators/detail/mpi_client.h @@ -26,23 +26,26 @@ namespace operators { namespace detail { class MPIClient { public: - bool AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = 600 * 1000); - - bool AsyncGetVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = 600 * 1000); - - void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); - - void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); + // bool AsyncSendVariable(const std::string& ep, + // const platform::DeviceContext& ctx, + // const framework::Scope& scope, + // const std::string& var_name, + // int64_t time_out = 600 * 1000); + + // bool AsyncGetVariable(const std::string& ep, + // const platform::DeviceContext& ctx, + // const framework::Scope& scope, + // const std::string& var_name, + // int64_t time_out = 600 * 1000); + + // void AsyncSendBatchBarrier(const std::string& ep, + // int64_t time_out = 600 * 1000); + + // void AsyncSendFetchBarrier(const std::string& ep, + // int64_t time_out = 600 * 1000); + + bool AsyncSendVariable(); + bool Wait(); private: diff --git a/paddle/fluid/operators/detail/mpi_utils.cpp b/paddle/fluid/operators/detail/mpi_utils.cpp index 6560761e6b9..370294fe213 100644 --- a/paddle/fluid/operators/detail/mpi_utils.cpp +++ b/paddle/fluid/operators/detail/mpi_utils.cpp @@ -3,10 +3,13 @@ // #include +#include -#include "paddle/fluid/operators/detail/mpi_utils.h" +#include +#include "mpi_utils.h" #define max_worker_name_length 128 +#define mpi_tag = 2008 namespace paddle { namespace operators { @@ -42,6 +45,47 @@ void MPIUtils::InitMPI() { MPI_Get_processor_name(host_name, &len) } }; + +MPIIsend::MPIIsend(int dst, const char* req) { + done1 = 0; + done2 = 0; + length = strlen(req); + req = req; +} + +MPIIsend::Send() { + MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD, + &msg1_); + MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE) +} + + bool MPIIsend::IsFinished() { + MPI_Status status; + if (!done1_) MPI_Test(&msg1_, &done1_, &status); + return done1; + } + +MPIIsend::~MPIIsend(){ + MPI_Wait(&msg1_, MPI_STATUS_IGNORE); + MPI_Free_mem(req); +} + +MPIIrecv::MPIIrecv(){ + +} + +MPIIrecv::Recv(){ + +} + +MPIIrecv::IsFinished(){ + +} + +MPIIrecv::~MPIIrecv(){ + +} + } // namespace detail } // namespace operators diff --git a/paddle/fluid/operators/detail/mpi_utils.h b/paddle/fluid/operators/detail/mpi_utils.h index 1f5ffdb18cf..a754439c268 100644 --- a/paddle/fluid/operators/detail/mpi_utils.h +++ b/paddle/fluid/operators/detail/mpi_utils.h @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include +#include #include #include #include @@ -30,22 +30,23 @@ class MPIUtils { class MPIIsend { public: - void init(); - int isFinished(); - void send(); + MPIIsend(int dst, const char* buf); + bool IsFinished(); + void Send(); ~MPIIsend(); private: - int done1; - int done2; - sendrecv::VariableMessage req; + int done1; + int length; + char* req; + MPI_Request msg1_; }; class MPIIrecv { public: - void init(); - int isFinished(); - void recv(); +MPIIrecv(); +bool IsFinished(); + void Recv(); ~MPIIrecv(); }; -- GitLab