提交 94ce0202 编写于 作者: T tangwei12

mpi tools

上级 b1adcd46
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpi_client.h"
#include "mpi_utils.h"
namespace paddle {
namespace operators {
namespace detail {
bool MPIClient::AsyncSendVariable() {
char* msg = "123456787654";
int dst = 1;
MPIIsend send = MPIIsend(dst, msg);
}
bool MPIClient::Wait() {}
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
......@@ -26,23 +26,26 @@ namespace operators {
namespace detail {
class MPIClient {
public:
bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
// bool AsyncSendVariable(const std::string& ep,
// const platform::DeviceContext& ctx,
// const framework::Scope& scope,
// const std::string& var_name,
// int64_t time_out = 600 * 1000);
// bool AsyncGetVariable(const std::string& ep,
// const platform::DeviceContext& ctx,
// const framework::Scope& scope,
// const std::string& var_name,
// int64_t time_out = 600 * 1000);
// void AsyncSendBatchBarrier(const std::string& ep,
// int64_t time_out = 600 * 1000);
// void AsyncSendFetchBarrier(const std::string& ep,
// int64_t time_out = 600 * 1000);
bool AsyncSendVariable();
bool Wait();
private:
......
......@@ -3,10 +3,13 @@
//
#include <stdio.h>
#include <string.h>
#include "paddle/fluid/operators/detail/mpi_utils.h"
#include <mpi.h>
#include "mpi_utils.h"
#define max_worker_name_length 128
#define mpi_tag = 2008
namespace paddle {
namespace operators {
......@@ -42,6 +45,47 @@ void MPIUtils::InitMPI() {
MPI_Get_processor_name(host_name, &len)
}
};
MPIIsend::MPIIsend(int dst, const char* req) {
done1 = 0;
done2 = 0;
length = strlen(req);
req = req;
}
MPIIsend::Send() {
MPI_Isend(&req, length, MPI_CHAR, dst, mpi_tag, MPI_COMM_WORLD,
&msg1_);
MPI_Test(&msg1_, &done1_, MPI_STATUS_IGNORE)
}
bool MPIIsend::IsFinished() {
MPI_Status status;
if (!done1_) MPI_Test(&msg1_, &done1_, &status);
return done1;
}
MPIIsend::~MPIIsend(){
MPI_Wait(&msg1_, MPI_STATUS_IGNORE);
MPI_Free_mem(req);
}
MPIIrecv::MPIIrecv(){
}
MPIIrecv::Recv(){
}
MPIIrecv::IsFinished(){
}
MPIIrecv::~MPIIrecv(){
}
} // namespace detail
} // namespace operators
......
......@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mpi/mpi.h>
#include <mpi.h>
#include <map>
#include <string>
#include <vector>
......@@ -30,22 +30,23 @@ class MPIUtils {
class MPIIsend {
public:
void init();
int isFinished();
void send();
MPIIsend(int dst, const char* buf);
bool IsFinished();
void Send();
~MPIIsend();
private:
int done1;
int done2;
sendrecv::VariableMessage req;
int length;
char* req;
MPI_Request msg1_;
};
class MPIIrecv {
public:
void init();
int isFinished();
void recv();
MPIIrecv();
bool IsFinished();
void Recv();
~MPIIrecv();
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册