提交 900b3e97 编写于 作者: T tangwei12

send mpi and rpc framework

上级 9d256dd1
...@@ -12,81 +12,79 @@ ...@@ -12,81 +12,79 @@
#define mpi_tag = 2008 #define mpi_tag = 2008
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
MPIUtils::MPIUtils(const std::string& worker_name) { MPIUtils::MPIUtils(const std::string &worker_name) {
InitMPI(); InitMPI();
int rank = 0, size = 1; int rank = 0, size = 1;
char my_name[max_work_group_size]; char my_name[max_work_group_size];
MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_size(MPI_COMM_WORLD, &size);
snprintf(my_name, max_worker_name_length, worker_name.c_str()); snprintf(my_name, max_worker_name_length, worker_name.c_str());
std::vector<char> worker_names(size * max_worker_name_length); std::vector<char> worker_names(size * max_worker_name_length);
MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0], MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0],
max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD); max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD);
for (int i = 0; i < number_of_procs; i++) { for (int i = 0; i < number_of_procs; i++) {
name_to_id_[std::string(&worker_names[i * 128])] = i; name_to_id_[std::string(&worker_names[i * 128])] = i;
} }
} }
void MPIUtils::InitMPI() { void MPIUtils::InitMPI() {
int flag = 0; int flag = 0;
MPI_CHECK(MPI_Initialized(&flag)); MPI_CHECK(MPI_Initialized(&flag));
if (!flag) { if (!flag) {
int rank = 0, size = 1, len = -1; int rank = 0, size = 1, len = -1;
char host_name[max_worker_name_length]; char host_name[max_worker_name_length];
MPI_Init(0, 0); MPI_Init(0, 0);
MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Get_processor_name(host_name, &len) MPI_Get_processor_name(host_name, &len);
} }
}; };
MPIIsend::MPIIsend(int dst, const char* req) {
done1 = 0;
done2 = 0;
length = strlen(req);
req = req;
}
MPIIsend::Send() { MPISend::MPISend(const Meta &meta) {
MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD, done1_ = 1;
&msg1_); done2_ = 0;
MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE) this->meta = meta;
} }
bool MPIIsend::IsFinished() { MPISend::Send() {
MPI_Status status; MPI_Send(&meta.request, meta.count, meta.datatype, meta.dst, meta.tag,
if (!done1_) MPI_Test(&msg1_, &done1_, &status); MPI_COMM_WORLD);
return done1; done2_ = 1;
} }
MPIIsend::~MPIIsend(){ bool MPISend::IsReady() {
MPI_Wait(&msg1_, MPI_STATUS_IGNORE); return true;
MPI_Free_mem(req); }
}
MPIIrecv::MPIIrecv(){ bool MPISend::IsFinished() { return done1_ && done2_; }
} MPISend::~MPISend() { MPI_Free_mem(meta); }
MPIIrecv::Recv(){
} MPIRecv::MPIRecv(const Meta &meta) {
this->meta = meta;
}
MPIIrecv::IsFinished(){ MPIRecv::Recv() {}
} bool MPIRecv::IsReady() {
return true;
}
MPIIrecv::~MPIIrecv(){ MPIRecv::IsFinished() {}
} MPIRecv::~MPIRecv() {
MPI_Free_mem(meta);
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
\ No newline at end of file
...@@ -10,46 +10,74 @@ See the License for the specific language governing permissions and ...@@ -10,46 +10,74 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <mpi.h> #include <mpi.h>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
class MPIUtils { class MPIUtils {
public: public:
MPIUtils(const std::string& worker_name); MPIUtils(const std::string &worker_name);
const int GetRankID(const std::string& task_id);
const int GetRankID(const std::string &task_id);
private:
void InitMPI(); private:
std::map<std::string, int> name_id_map; void InitMPI();
};
std::map<std::string, int> name_id_map;
class MPIIsend { };
public:
MPIIsend(int dst, const char* buf); class Meta {
bool IsFinished(); public:
void Send(); int src;
~MPIIsend(); int dst;
MPI_Datatype datatype;
private: char *request;
int done1; int count;
int length; int tag;
char* req; int device;
MPI_Request msg1_; };
};
class MPISend {
class MPIIrecv { public:
public: MPISend(const Meta &meta);
MPIIrecv();
bool IsFinished(); bool IsFinished();
void Recv();
~MPIIrecv(); bool IsReady();
};
void Send();
} // namespace detail
} // namespace operators ~MPISend();
private:
int done1_;
int done2_;
Meta *meta;
};
class MPIRecv {
public:
MPIRecv(const Meta &meta);
bool IsReady();
bool IsFinished();
void Recv();
~MPIRecv();
private:
int done1_;
int done2_;
Meta *meta;
};
} // namespace detail
} // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册