提交 94ce0202 编写于 作者: T tangwei12

mpi tools

上级 b1adcd46
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpi_client.h"
#include "mpi_utils.h"
namespace paddle {
namespace operators {
namespace detail {
bool MPIClient::AsyncSendVariable() {
char* msg = "123456787654";
int dst = 1;
MPIIsend send = MPIIsend(dst, msg);
}
bool MPIClient::Wait() {}
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
...@@ -26,23 +26,26 @@ namespace operators { ...@@ -26,23 +26,26 @@ namespace operators {
namespace detail { namespace detail {
class MPIClient { class MPIClient {
public: public:
bool AsyncSendVariable(const std::string& ep, // bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, // const platform::DeviceContext& ctx,
const framework::Scope& scope, // const framework::Scope& scope,
const std::string& var_name, // const std::string& var_name,
int64_t time_out = 600 * 1000); // int64_t time_out = 600 * 1000);
bool AsyncGetVariable(const std::string& ep, // bool AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx, // const platform::DeviceContext& ctx,
const framework::Scope& scope, // const framework::Scope& scope,
const std::string& var_name, // const std::string& var_name,
int64_t time_out = 600 * 1000); // int64_t time_out = 600 * 1000);
void AsyncSendBatchBarrier(const std::string& ep, // void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000); // int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(const std::string& ep, // void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000); // int64_t time_out = 600 * 1000);
bool AsyncSendVariable();
bool Wait(); bool Wait();
private: private:
......
...@@ -3,10 +3,13 @@ ...@@ -3,10 +3,13 @@
// //
#include <stdio.h> #include <stdio.h>
#include <string.h>
#include "paddle/fluid/operators/detail/mpi_utils.h" #include <mpi.h>
#include "mpi_utils.h"
#define max_worker_name_length 128 #define max_worker_name_length 128
#define mpi_tag = 2008
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,6 +45,47 @@ void MPIUtils::InitMPI() { ...@@ -42,6 +45,47 @@ void MPIUtils::InitMPI() {
MPI_Get_processor_name(host_name, &len) MPI_Get_processor_name(host_name, &len)
} }
}; };
MPIIsend::MPIIsend(int dst, const char* req) {
done1 = 0;
done2 = 0;
length = strlen(req);
req = req;
}
MPIIsend::Send() {
MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD,
&msg1_);
MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE)
}
bool MPIIsend::IsFinished() {
MPI_Status status;
if (!done1_) MPI_Test(&msg1_, &done1_, &status);
return done1;
}
MPIIsend::~MPIIsend(){
MPI_Wait(&msg1_, MPI_STATUS_IGNORE);
MPI_Free_mem(req);
}
MPIIrecv::MPIIrecv(){
}
MPIIrecv::Recv(){
}
MPIIrecv::IsFinished(){
}
MPIIrecv::~MPIIrecv(){
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
......
...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <mpi/mpi.h> #include <mpi.h>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -30,22 +30,23 @@ class MPIUtils { ...@@ -30,22 +30,23 @@ class MPIUtils {
class MPIIsend { class MPIIsend {
public: public:
void init(); MPIIsend(int dst, const char* buf);
int isFinished(); bool IsFinished();
void send(); void Send();
~MPIIsend(); ~MPIIsend();
private: private:
int done1; int done1;
int done2; int length;
sendrecv::VariableMessage req; char* req;
MPI_Request msg1_;
}; };
class MPIIrecv { class MPIIrecv {
public: public:
void init(); MPIIrecv();
int isFinished(); bool IsFinished();
void recv(); void Recv();
~MPIIrecv(); ~MPIIrecv();
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册