提交 900b3e97 编写于 作者: T tangwei12

send mpi and rpc framework

上级 9d256dd1
......@@ -12,81 +12,79 @@
#define mpi_tag = 2008
namespace paddle {
namespace operators {
namespace detail {
MPIUtils::MPIUtils(const std::string& worker_name) {
InitMPI();
namespace operators {
namespace detail {
MPIUtils::MPIUtils(const std::string &worker_name) {
InitMPI();
int rank = 0, size = 1;
char my_name[max_work_group_size];
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
snprintf(my_name, max_worker_name_length, worker_name.c_str());
int rank = 0, size = 1;
char my_name[max_work_group_size];
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
snprintf(my_name, max_worker_name_length, worker_name.c_str());
std::vector<char> worker_names(size * max_worker_name_length);
MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0],
max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD);
for (int i = 0; i < number_of_procs; i++) {
name_to_id_[std::string(&worker_names[i * 128])] = i;
}
}
std::vector<char> worker_names(size * max_worker_name_length);
MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0],
max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD);
for (int i = 0; i < number_of_procs; i++) {
name_to_id_[std::string(&worker_names[i * 128])] = i;
}
}
void MPIUtils::InitMPI() {
int flag = 0;
MPI_CHECK(MPI_Initialized(&flag));
void MPIUtils::InitMPI() {
int flag = 0;
MPI_CHECK(MPI_Initialized(&flag));
if (!flag) {
int rank = 0, size = 1, len = -1;
char host_name[max_worker_name_length];
if (!flag) {
int rank = 0, size = 1, len = -1;
char host_name[max_worker_name_length];
MPI_Init(0, 0);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Get_processor_name(host_name, &len)
}
};
MPI_Init(0, 0);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Get_processor_name(host_name, &len);
}
};
MPIIsend::MPIIsend(int dst, const char* req) {
done1 = 0;
done2 = 0;
length = strlen(req);
req = req;
}
MPIIsend::Send() {
MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD,
&msg1_);
MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE)
}
MPISend::MPISend(const Meta &meta) {
done1_ = 1;
done2_ = 0;
this->meta = meta;
}
bool MPIIsend::IsFinished() {
MPI_Status status;
if (!done1_) MPI_Test(&msg1_, &done1_, &status);
return done1;
}
MPISend::Send() {
MPI_Send(&meta.request, meta.count, meta.datatype, meta.dst, meta.tag,
MPI_COMM_WORLD);
done2_ = 1;
}
MPIIsend::~MPIIsend(){
MPI_Wait(&msg1_, MPI_STATUS_IGNORE);
MPI_Free_mem(req);
}
bool MPISend::IsReady() {
return true;
}
MPIIrecv::MPIIrecv(){
bool MPISend::IsFinished() { return done1_ && done2_; }
}
MPISend::~MPISend() { MPI_Free_mem(meta); }
MPIIrecv::Recv(){
}
MPIRecv::MPIRecv(const Meta &meta) {
this->meta = meta;
}
MPIIrecv::IsFinished(){
MPIRecv::Recv() {}
}
bool MPIRecv::IsReady() {
return true;
}
MPIIrecv::~MPIIrecv(){
MPIRecv::IsFinished() {}
}
MPIRecv::~MPIRecv() {
MPI_Free_mem(meta);
}
} // namespace detail
} // namespace detail
} // namespace operators
} // namespace operators
} // namespace paddle
\ No newline at end of file
......@@ -10,46 +10,74 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mpi.h>
#include <map>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
namespace detail {
class MPIUtils {
public:
MPIUtils(const std::string& worker_name);
const int GetRankID(const std::string& task_id);
private:
void InitMPI();
std::map<std::string, int> name_id_map;
};
class MPIIsend {
public:
MPIIsend(int dst, const char* buf);
bool IsFinished();
void Send();
~MPIIsend();
private:
int done1;
int length;
char* req;
MPI_Request msg1_;
};
class MPIIrecv {
public:
MPIIrecv();
bool IsFinished();
void Recv();
~MPIIrecv();
};
} // namespace detail
} // namespace operators
namespace operators {
namespace detail {
class MPIUtils {
public:
MPIUtils(const std::string &worker_name);
const int GetRankID(const std::string &task_id);
private:
void InitMPI();
std::map<std::string, int> name_id_map;
};
class Meta {
public:
int src;
int dst;
MPI_Datatype datatype;
char *request;
int count;
int tag;
int device;
};
class MPISend {
public:
MPISend(const Meta &meta);
bool IsFinished();
bool IsReady();
void Send();
~MPISend();
private:
int done1_;
int done2_;
Meta *meta;
};
class MPIRecv {
public:
MPIRecv(const Meta &meta);
bool IsReady();
bool IsFinished();
void Recv();
~MPIRecv();
private:
int done1_;
int done2_;
Meta *meta;
};
} // namespace detail
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册