提交 900b3e97 编写于 作者: T tangwei12

send mpi and rpc framework

上级 9d256dd1
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
#define mpi_tag = 2008 #define mpi_tag = 2008
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
MPIUtils::MPIUtils(const std::string& worker_name) { MPIUtils::MPIUtils(const std::string &worker_name) {
InitMPI(); InitMPI();
int rank = 0, size = 1; int rank = 0, size = 1;
...@@ -29,9 +29,9 @@ MPIUtils::MPIUtils(const std::string& worker_name) { ...@@ -29,9 +29,9 @@ MPIUtils::MPIUtils(const std::string& worker_name) {
for (int i = 0; i < number_of_procs; i++) { for (int i = 0; i < number_of_procs; i++) {
name_to_id_[std::string(&worker_names[i * 128])] = i; name_to_id_[std::string(&worker_names[i * 128])] = i;
} }
} }
void MPIUtils::InitMPI() { void MPIUtils::InitMPI() {
int flag = 0; int flag = 0;
MPI_CHECK(MPI_Initialized(&flag)); MPI_CHECK(MPI_Initialized(&flag));
...@@ -42,51 +42,49 @@ void MPIUtils::InitMPI() { ...@@ -42,51 +42,49 @@ void MPIUtils::InitMPI() {
MPI_Init(0, 0); MPI_Init(0, 0);
MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Get_processor_name(host_name, &len) MPI_Get_processor_name(host_name, &len);
} }
}; };
MPIIsend::MPIIsend(int dst, const char* req) {
done1 = 0; MPISend::MPISend(const Meta &meta) {
done2 = 0; done1_ = 1;
length = strlen(req); done2_ = 0;
req = req; this->meta = meta;
}
MPIIsend::Send() {
MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD,
&msg1_);
MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE)
}
bool MPIIsend::IsFinished() {
MPI_Status status;
if (!done1_) MPI_Test(&msg1_, &done1_, &status);
return done1;
} }
MPIIsend::~MPIIsend(){ MPISend::Send() {
MPI_Wait(&msg1_, MPI_STATUS_IGNORE); MPI_Send(&meta.request, meta.count, meta.datatype, meta.dst, meta.tag,
MPI_Free_mem(req); MPI_COMM_WORLD);
} done2_ = 1;
}
MPIIrecv::MPIIrecv(){ bool MPISend::IsReady() {
return true;
}
} bool MPISend::IsFinished() { return done1_ && done2_; }
MPIIrecv::Recv(){ MPISend::~MPISend() { MPI_Free_mem(meta); }
}
MPIIrecv::IsFinished(){ MPIRecv::MPIRecv(const Meta &meta) {
this->meta = meta;
}
MPIRecv::Recv() {}
} bool MPIRecv::IsReady() {
return true;
}
MPIIrecv::~MPIIrecv(){ MPIRecv::IsFinished() {}
} MPIRecv::~MPIRecv() {
MPI_Free_mem(meta);
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
\ No newline at end of file
...@@ -10,46 +10,74 @@ See the License for the specific language governing permissions and ...@@ -10,46 +10,74 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <mpi.h> #include <mpi.h>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
class MPIUtils { class MPIUtils {
public: public:
MPIUtils(const std::string& worker_name); MPIUtils(const std::string &worker_name);
const int GetRankID(const std::string& task_id);
const int GetRankID(const std::string &task_id);
private: private:
void InitMPI(); void InitMPI();
std::map<std::string, int> name_id_map; std::map<std::string, int> name_id_map;
}; };
class Meta {
public:
int src;
int dst;
MPI_Datatype datatype;
char *request;
int count;
int tag;
int device;
};
class MPIIsend { class MPISend {
public: public:
MPIIsend(int dst, const char* buf); MPISend(const Meta &meta);
bool IsFinished(); bool IsFinished();
bool IsReady();
void Send(); void Send();
~MPIIsend();
~MPISend();
private: private:
int done1; int done1_;
int length; int done2_;
char* req; Meta *meta;
MPI_Request msg1_; };
};
class MPIIrecv { class MPIRecv {
public: public:
MPIIrecv(); MPIRecv(const Meta &meta);
bool IsFinished();
bool IsReady();
bool IsFinished();
void Recv(); void Recv();
~MPIIrecv();
};
} // namespace detail ~MPIRecv();
} // namespace operators
private:
int done1_;
int done2_;
Meta *meta;
};
} // namespace detail
} // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册