diff --git a/paddle/fluid/operators/detail/mpi_client.h b/paddle/fluid/operators/detail/mpi_client.h new file mode 100644 index 0000000000000000000000000000000000000000..14dcd678a09e87089d651d39c3233802aa592491 --- /dev/null +++ b/paddle/fluid/operators/detail/mpi_client.h @@ -0,0 +1,53 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" + +namespace paddle { +namespace operators { +namespace detail { +class MPIClient { + public: + bool AsyncSendVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + + bool AsyncGetVariable(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = 600 * 1000); + + void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + + void AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + bool Wait(); + + private: + int64_t req_count_ = 0; +}; +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/mpi_server.h b/paddle/fluid/operators/detail/mpi_server.h new file mode 100644 index 0000000000000000000000000000000000000000..dda99318afa3d9f0f6b5a32b8c5d4e77e677ff9e --- /dev/null +++ b/paddle/fluid/operators/detail/mpi_server.h @@ -0,0 +1,23 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +namespace paddle { +namespace operators { +namespace detail { +class MPIServer { + public: + private: +}; +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/mpi_utils.cpp b/paddle/fluid/operators/detail/mpi_utils.cpp index adf4a3b92541b8f110f8bb1bf07cd269fd91a920..6560761e6b9f12f1f2b947817321e53d852ffb74 100644 --- a/paddle/fluid/operators/detail/mpi_utils.cpp +++ b/paddle/fluid/operators/detail/mpi_utils.cpp @@ -2,3 +2,47 @@ // Created by tangwei12 on 2018/3/27. // +#include + +#include "paddle/fluid/operators/detail/mpi_utils.h" + +#define max_worker_name_length 128 + +namespace paddle { +namespace operators { +namespace detail { +MPIUtils::MPIUtils(const std::string& worker_name) { + InitMPI(); + + int rank = 0, size = 1; + char my_name[max_work_group_size]; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + snprintf(my_name, max_worker_name_length, worker_name.c_str()); + + std::vector worker_names(size * max_worker_name_length); + MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, &worker_names[0], + max_worker_name_length, MPI_CHAR, MPI_COMM_WORLD); + for (int i = 0; i < number_of_procs; i++) { + name_to_id_[std::string(&worker_names[i * 128])] = i; + } +} + +void MPIUtils::InitMPI() { + int flag = 0; + MPI_CHECK(MPI_Initialized(&flag)); + + if (!flag) { + int rank = 0, size = 1, len = -1; + char host_name[max_worker_name_length]; + + MPI_Init(0, 0); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Get_processor_name(host_name, &len) + } +}; +} // namespace detail + +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/mpi_utils.h b/paddle/fluid/operators/detail/mpi_utils.h index fb2f141246118c0430ec77be88ae90cecbb5829f..1f5ffdb18cf3b74ed556cc54645565914ba496e0 100644 --- a/paddle/fluid/operators/detail/mpi_utils.h +++ b/paddle/fluid/operators/detail/mpi_utils.h @@ -1,8 +1,54 @@ -// -// Created by tangwei12 on 2018/3/27. -// +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ -#ifndef PADDLE_MPI_UTILS_H -#define PADDLE_MPI_UTILS_H +#pragma once +#include +#include +#include +#include -#endif //PADDLE_MPI_UTILS_H +namespace paddle { +namespace operators { +namespace detail { +class MPIUtils { + public: + MPIUtils(const std::string& worker_name); + const int GetRankID(const std::string& task_id); + + private: + void InitMPI(); + std::map name_id_map; +}; + +class MPIIsend { + public: + void init(); + int isFinished(); + void send(); + ~MPIIsend(); + + private: + int done1; + int done2; + sendrecv::VariableMessage req; +}; + +class MPIIrecv { + public: + void init(); + int isFinished(); + void recv(); + ~MPIIrecv(); +}; + +} // namespace detail +} // namespace operators +} // namespace paddle