Unverified commit f0afcabc, authored by X Xinger, committed by GitHub

[WIP] PaddlePaddle distributed reinforcement learning feature development (#45998)

* add rpc module in cpp side

* add rpc module in python side

* support win32 and mac for rpc

* code optimization

* optimize code

* update rpc

* update rpc launch

* rpc remove rank and world_size api

* fix logger import bug

* remove support for win and mac

* remove support for xpu, npu, cinn and rocm

* remove support for xpu, npu, cinn and rocm

* fix shutdown barrier timeout bug

* update: python_rpc_handler to shared ptr

* fix master shutdown first bug

* tests support for cpu

* update log to vlog

* update get service info api

* add single process test case

* remove process group

* remove some useless dependencies

* update rpc api comments

* update rpc comments: Example to Examples

* update rpc api comments

* update rpc api comments

* update launch api comments

* update init_rpc comments

* update rpc sync and async comments

* fix bug: init_rpc can't be called repeatedly in a process

* update rpc api comment: make master endpoint unique

* update rpc api: service to worker, timeout_ms to timeout

* rename ServiceInfo to WorkerInfo

* refactor: rename server to worker, log to vlog

* add launch test

* remove unused codes

* refine
Parent 8474392d
@@ -42,7 +42,9 @@ set(DISTRIBUTE_COMPILE_FLAGS
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
if(LINUX)
add_subdirectory(rpc)
endif()
add_subdirectory(common)
add_subdirectory(ps)
add_subdirectory(test)
......
set(PADDLE_RPC_SRCS python_rpc_handler.cc rpc_agent.cc)
set_source_files_properties(
python_rpc_handler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set(PADDLE_RPC_DEPS brpc protobuf glog pybind)
proto_library(paddle_rpc_proto SRCS rpc.proto)
cc_library(
paddle_rpc
SRCS ${PADDLE_RPC_SRCS}
DEPS ${PADDLE_RPC_DEPS} paddle_rpc_proto)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
#include <cassert>
#include <future>
#include <string>
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
namespace py = pybind11;
namespace paddle {
namespace distributed {
class FutureWrapper {
public:
FutureWrapper() {}
explicit FutureWrapper(std::future<std::string> fut) : fut_(std::move(fut)) {}
py::object wait() {
// GIL must be released, otherwise fut_.get() blocking will cause the
// service to fail to process RPC requests, leading to deadlock
PADDLE_ENFORCE_EQ(
PyGILState_Check(),
false,
platform::errors::Fatal(
"GIL must be released before fut.wait(), otherwise fut_.get() "
"blocking will cause the service to fail to "
"process RPC requests, leading to deadlock"));
auto s = fut_.get();
py::gil_scoped_acquire ag;
std::shared_ptr<PythonRpcHandler> python_handler =
PythonRpcHandler::GetInstance();
py::object obj = python_handler->Deserialize(py::bytes(s));
return obj;
}
private:
DISABLE_COPY_AND_ASSIGN(FutureWrapper);
std::future<std::string> fut_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
namespace paddle {
namespace distributed {
constexpr auto kInternalModule = "paddle.distributed.rpc.internal";
py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name);
return fn;
}
PythonRpcHandler::PythonRpcHandler() {
py::gil_scoped_acquire ag;
// import python module
py::object rpc_internal = py::module::import(kInternalModule);
py_run_function_ = getFunction(rpc_internal, "_run_py_func");
py_serialize_ = getFunction(rpc_internal, "_serialize");
py_deserialize_ = getFunction(rpc_internal, "_deserialize");
}
py::object PythonRpcHandler::RunPythonFunc(const py::object& python_func) {
py::gil_scoped_acquire ag;
return py_run_function_(python_func);
}
std::string PythonRpcHandler::Serialize(const py::object& obj) {
py::gil_scoped_acquire ag;
py::object res = py_serialize_(obj);
return res.cast<std::string>();
}
py::object PythonRpcHandler::Deserialize(const std::string& obj) {
py::gil_scoped_acquire ag;
return py_deserialize_(py::bytes(obj));
}
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::python_rpc_handler_ =
nullptr;
std::mutex PythonRpcHandler::lock_;
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::GetInstance() {
if (python_rpc_handler_ == nullptr) {
std::lock_guard<std::mutex> guard(lock_);
if (python_rpc_handler_ == nullptr) {
python_rpc_handler_ = std::make_shared<PythonRpcHandler>();
return python_rpc_handler_;
}
}
return python_rpc_handler_;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
#include <memory>
#include <mutex>
#include <string>
#include "paddle/fluid/platform/macros.h"
namespace py = pybind11;
namespace paddle {
namespace distributed {
class PYBIND11_EXPORT PythonRpcHandler {
public:
PythonRpcHandler();
~PythonRpcHandler() = default;
static std::shared_ptr<PythonRpcHandler> GetInstance();
// Run a pickled Python function and return the result py::object
py::object RunPythonFunc(const py::object& python_func);
// Serialize a py::object into a string
std::string Serialize(const py::object& obj);
// Deserialize a string into a py::object
py::object Deserialize(const std::string& obj);
private:
DISABLE_COPY_AND_ASSIGN(PythonRpcHandler);
static std::shared_ptr<PythonRpcHandler> python_rpc_handler_;
// Ref to `paddle.distributed.rpc.internal._run_py_func`.
py::object py_run_function_;
// Ref to `paddle.distributed.rpc.internal._serialize`.
py::object py_serialize_;
// Ref to `paddle.distributed.rpc.internal._deserialize`.
py::object py_deserialize_;
// Lock to protect initialization.
static std::mutex lock_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax="proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
message RpcRequest {
required bytes message = 1;
};
message RpcResponse {
required bytes message = 1;
};
service RpcBaseService {
rpc Send(RpcRequest) returns (RpcResponse);
rpc InvokeRpc(RpcRequest) returns (RpcResponse);
};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/rpc/rpc_agent.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
const int kTimeoutMs = 500000;
const int kConnectTimeoutMs = 10000;
const int kMaxRetry = 5;
const int kCloseWaitMs = 1000;
std::shared_ptr<RpcAgent> RpcAgent::rpc_agent_instance_ = nullptr;
RpcAgent::RpcAgent(std::string name, std::vector<WorkerInfo> infos) {
name_ = std::move(name);
for (auto info : infos) {
name_to_infos_.insert({info.name_, info});
id_to_infos_.insert({info.id_, info});
}
this->infos_ = std::move(infos);
auto it = name_to_infos_.find(name_);
this->rank_ = it->second.id_;
rpc_service_ = std::make_shared<RpcService>();
PADDLE_ENFORCE_EQ(
server_.AddService(rpc_service_.get(), brpc::SERVER_DOESNT_OWN_SERVICE),
0,
platform::errors::Fatal("Fail to add service: %s", name));
}
int RpcAgent::StartWorker() {
auto info = GetWorkerInfo(name_);
// Start the server.
int port = info.port_;
brpc::ServerOptions options;
PADDLE_ENFORCE_EQ(server_.Start(port, &options),
0,
platform::errors::Fatal("Fail to start worker: %s", name_));
VLOG(0) << "Start worker : " << name_;
return 0;
}
int RpcAgent::StartClient() {
// Initialize the channel, NULL means using default options.
brpc::ChannelOptions channel_options;
channel_options.protocol = "baidu_std";
channel_options.timeout_ms = kTimeoutMs;
channel_options.connection_type = "pooled";
channel_options.connect_timeout_ms = kConnectTimeoutMs;
channel_options.max_retry = kMaxRetry;
channels_.resize(name_to_infos_.size());
// build connection from client to all servers
for (std::size_t i = 0; i < channels_.size(); i++) {
auto info = id_to_infos_.find(i)->second;
channels_[i].reset(new brpc::Channel());
PADDLE_ENFORCE_EQ(
channels_[i]->Init(info.ip_.c_str(), info.port_, &channel_options),
0,
platform::errors::Fatal(
"Fail to initialize channel: %d, ip: %s, port: %d",
i,
info.ip_,
info.port_));
}
VLOG(0) << "Init Channels: " << name_;
return 0;
}
int RpcAgent::Stop() {
VLOG(0) << "Worker: " << name_ << " is going to stop.";
server_.Stop(kCloseWaitMs);
server_.Join();
rpc_agent_instance_ = nullptr;
VLOG(0) << "Worker: " << name_ << " has stopped";
return 0;
}
void OnRpcDone::Run() {
// delete this after Run
std::unique_ptr<OnRpcDone> self_guard(this);
PADDLE_ENFORCE_EQ(
cntl_.Failed(), false, platform::errors::Fatal(cntl_.ErrorText()));
promise_->set_value(response_.message());
VLOG(2) << "Received response from " << cntl_.remote_side() << " to "
<< cntl_.local_side() << " (attached=" << cntl_.response_attachment()
<< ")"
<< " latency=" << cntl_.latency_us() << "us";
}
std::future<std::string> RpcAgent::InvokeRpc(const std::string &py_func,
const std::string &to,
int timeout_ms = kTimeoutMs) {
auto it = name_to_infos_.find(to);
PADDLE_ENFORCE_NE(
it,
name_to_infos_.end(),
platform::errors::OutOfRange("Worker %s doesn't exist!", to));
uint32_t id = it->second.id_;
auto channel = channels_[id];
// `done` must be allocated on the heap because it has to stay alive until the
// asynchronous call completes; it deletes itself inside done->Run().
OnRpcDone *done = new OnRpcDone;
done->cntl_.set_timeout_ms(timeout_ms);
done->request_.set_message(py_func);
std::future<std::string> fut = done->GetFuture();
RpcBaseService_Stub stub(channel.get());
stub.InvokeRpc(&done->cntl_, &done->request_, &done->response_, done);
return fut;
}
std::shared_ptr<RpcAgent> RpcAgent::RpcAgentInstance() {
PADDLE_ENFORCE_NE(rpc_agent_instance_,
nullptr,
platform::errors::Fatal(
"RpcAgent is not set, please calling "
"paddle.distributed.rpc.int_rpc() to init rpc agent."));
return rpc_agent_instance_;
}
void RpcAgent::SetAgentInstance(std::shared_ptr<RpcAgent> agent) {
PADDLE_ENFORCE_EQ(
rpc_agent_instance_,
nullptr,
platform::errors::Fatal(
"RpcAgent has been set, please don't set rpc agent repeatly."));
rpc_agent_instance_ = agent;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/distributed/rpc/rpc.pb.h"
#include "paddle/fluid/distributed/rpc/rpc_service.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
struct WorkerInfo {
std::string name_;
uint32_t id_;
std::string ip_;
uint32_t port_;
WorkerInfo(std::string name, uint32_t id, std::string ip, uint32_t port)
: name_(std::move(name)), id_(id), ip_(std::move(ip)), port_(port) {}
std::string to_string() const {
std::string info = "{name: " + name_ + ", rank: " + std::to_string(id_) +
", ip: " + ip_ + ", port: " + std::to_string(port_) +
"}";
return info;
}
};
class OnRpcDone : public google::protobuf::Closure {
public:
OnRpcDone() { promise_ = std::make_shared<std::promise<std::string>>(); }
// callback that processes the RPC response
void Run();
std::future<std::string> GetFuture() {
return std::future<std::string>(promise_->get_future());
}
RpcResponse response_;
RpcRequest request_;
brpc::Controller cntl_;
std::shared_ptr<std::promise<std::string>> promise_;
};
class RpcAgent {
public:
static std::shared_ptr<RpcAgent> RpcAgentInstance();
static void SetAgentInstance(std::shared_ptr<RpcAgent> agent);
// init RpcAgent instance with the information of all workers
RpcAgent(std::string name, std::vector<WorkerInfo> infos);
~RpcAgent() {}
const WorkerInfo &GetWorkerInfo(const std::string &name) const {
auto it = name_to_infos_.find(name);
return it->second;
}
const WorkerInfo &GetWorkerInfoById(uint32_t id) const {
auto it = id_to_infos_.find(id);
return it->second;
}
const WorkerInfo &GetCurrentWorkerInfo() const {
return GetWorkerInfo(name_);
}
const std::vector<WorkerInfo> &GetAllWorkerInfos() const {
return this->infos_;
}
uint32_t Rank() { return this->rank_; }
uint32_t WorldSize() { return infos_.size(); }
int StartWorker();
// build connection from client to all servers
int StartClient();
int Stop();
std::future<std::string> InvokeRpc(const std::string &msg,
const std::string &to,
int timeout_ms);
private:
DISABLE_COPY_AND_ASSIGN(RpcAgent);
static std::shared_ptr<RpcAgent> rpc_agent_instance_;
brpc::Server server_;
std::shared_ptr<RpcService> rpc_service_;
std::vector<std::shared_ptr<brpc::Channel>> channels_;
std::string name_;
uint32_t rank_;
std::unordered_map<std::string, WorkerInfo> name_to_infos_;
std::unordered_map<uint32_t, WorkerInfo> id_to_infos_;
std::vector<WorkerInfo> infos_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <brpc/server.h>
#include <string>
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/distributed/rpc/rpc.pb.h"
namespace paddle {
namespace distributed {
class RpcService : public RpcBaseService {
public:
RpcService() {}
virtual ~RpcService() {}
virtual void InvokeRpc(google::protobuf::RpcController *cntl_base,
const RpcRequest *request,
RpcResponse *response,
google::protobuf::Closure *done) {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
brpc::ClosureGuard done_guard(done);
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
VLOG(2) << "InvokeRpc API: Received request[log_id=" << cntl->log_id()
<< "] from " << cntl->remote_side() << " to " << cntl->local_side()
<< ": "
<< " (attached=" << cntl->request_attachment() << ")";
std::string py_func_str = request->message();
std::shared_ptr<PythonRpcHandler> python_handler =
PythonRpcHandler::GetInstance();
// acquire gil, because native Python objects are used
py::gil_scoped_acquire ag;
py::object py_func_obj = python_handler->Deserialize(py_func_str);
py::object res = python_handler->RunPythonFunc(py_func_obj);
std::string res_str = python_handler->Serialize(res);
response->set_message(res_str);
}
};
} // namespace distributed
} // namespace paddle
@@ -49,6 +49,14 @@ if(WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} graph_gpu_wrapper)
endif()
endif()
if(WITH_DISTRIBUTE
AND LINUX
AND NOT WITH_ASCEND_CL
AND NOT WITH_XPU
AND NOT WITH_CINN
AND NOT WITH_ROCM)
set(PYBIND_DEPS ${PYBIND_DEPS} paddle_rpc)
endif()
if(WITH_GPU OR WITH_ROCM)
set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda)
set(PYBIND_DEPS ${PYBIND_DEPS} cuda_device_guard)
@@ -218,6 +226,29 @@ if(WITH_PSCORE)
set(PYBIND_SRCS fleet_py.cc ${PYBIND_SRCS})
endif()
if(WITH_DISTRIBUTE
AND LINUX
AND NOT WITH_ASCEND_CL
AND NOT WITH_XPU
AND NOT WITH_CINN
AND NOT WITH_ROCM)
if(WITH_ARM_BRPC)
set(DISTRIBUTE_COMPILE_FLAGS
"-faligned-new -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
else()
set(DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
endif()
set_source_files_properties(rpc.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set(PYBIND_SRCS rpc.cc ${PYBIND_SRCS})
endif()
if(WITH_NCCL OR WITH_RCCL)
list(APPEND PYBIND_SRCS nccl_wrapper_py.cc)
endif()
......
@@ -182,6 +182,12 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#endif
#if defined(__linux__) && !defined(PADDLE_WITH_XPU) && \
!defined(PADDLE_WITH_ASCEND_CL) && !defined(PADDLE_WITH_CINN) && \
!defined(PADDLE_WITH_HIP)
#include "paddle/fluid/pybind/rpc.h"
#endif
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pybind/eager_utils.h"
@@ -2602,6 +2608,21 @@ All parameter, weight, gradient are variables in Paddle.
BindGraphGpuWrapper(&m);
#endif
#endif
#if defined(__linux__) && !defined(PADDLE_WITH_XPU) && \
!defined(PADDLE_WITH_ASCEND_CL) && !defined(PADDLE_WITH_CINN) && \
!defined(PADDLE_WITH_HIP)
BindWorkerInfo(&m);
BindFuture(&m);
InitAndSetAgentInstance(&m);
InvokeRpc(&m);
StartWorker(&m);
StartClient(&m);
StopWorker(&m);
GetWorkerInfo(&m);
GetWorkerInfoByRank(&m);
GetCurrentWorkerInfo(&m);
GetAllWorkerInfos(&m);
#endif
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/rpc.h"
#include "paddle/fluid/distributed/rpc/future_wrapper.h"
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/distributed/rpc/rpc_agent.h"
namespace py = pybind11;
using paddle::distributed::FutureWrapper;
using paddle::distributed::PythonRpcHandler;
using paddle::distributed::RpcAgent;
using paddle::distributed::WorkerInfo;
namespace paddle {
namespace pybind {
void BindWorkerInfo(py::module* m) {
py::class_<WorkerInfo>(*m, "WorkerInfo")
.def(py::init<std::string, uint32_t, std::string, uint32_t>())
.def_readonly("name", &WorkerInfo::name_)
.def_readonly("rank", &WorkerInfo::id_)
.def_readonly("ip", &WorkerInfo::ip_)
.def_readonly("port", &WorkerInfo::port_)
.def("__str__", &WorkerInfo::to_string)
.def("__repr__", &WorkerInfo::to_string);
}
void BindFuture(py::module* m) {
py::class_<FutureWrapper, std::shared_ptr<FutureWrapper>>(*m, "Future")
.def(py::init<>())
.def("wait",
&FutureWrapper::wait,
py::call_guard<py::gil_scoped_release>());
}
void InitAndSetAgentInstance(py::module* m) {
m->def(
"init_and_set_agent_instance",
[](const std::string& name, const std::vector<WorkerInfo>& infos) {
auto instance = std::make_shared<RpcAgent>(name, infos);
instance->SetAgentInstance(instance);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("name"),
py::arg("infos"));
}
void InvokeRpc(py::module* m) {
m->def(
"invoke_rpc",
[](const std::string& name, const std::string& py_func, int timeout_ms) {
auto instance = RpcAgent::RpcAgentInstance();
return std::make_shared<FutureWrapper>(
instance->InvokeRpc(py_func, name, timeout_ms));
},
py::call_guard<py::gil_scoped_release>(),
py::arg("to"),
py::arg("py_func"),
py::arg("timeout_ms"));
}
void StartWorker(py::module* m) {
m->def(
"rpc_start_worker",
[]() {
auto instance = RpcAgent::RpcAgentInstance();
instance->StartWorker();
},
py::call_guard<py::gil_scoped_release>());
}
void StartClient(py::module* m) {
m->def(
"rpc_start_client",
[]() {
auto instance = RpcAgent::RpcAgentInstance();
instance->StartClient();
},
py::call_guard<py::gil_scoped_release>());
}
void StopWorker(py::module* m) {
m->def(
"rpc_stop_worker",
[]() {
auto instance = RpcAgent::RpcAgentInstance();
instance->Stop();
},
py::call_guard<py::gil_scoped_release>());
}
void GetWorkerInfo(py::module* m) {
m->def(
"rpc_get_worker_info",
[](const std::string& name) {
auto instance = RpcAgent::RpcAgentInstance();
return instance->GetWorkerInfo(name);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("name"));
}
void GetWorkerInfoByRank(py::module* m) {
m->def(
"rpc_get_worker_info_by_rank",
[](uint32_t rank) {
auto instance = RpcAgent::RpcAgentInstance();
return instance->GetWorkerInfoById(rank);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("rank"));
}
void GetCurrentWorkerInfo(py::module* m) {
m->def(
"rpc_get_current_worker_info",
[]() {
auto instance = RpcAgent::RpcAgentInstance();
return instance->GetCurrentWorkerInfo();
},
py::call_guard<py::gil_scoped_release>());
}
void GetAllWorkerInfos(py::module* m) {
m->def(
"rpc_get_all_worker_infos",
[]() {
auto instance = RpcAgent::RpcAgentInstance();
return instance->GetAllWorkerInfos();
},
py::call_guard<py::gil_scoped_release>());
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindWorkerInfo(py::module* m);
void BindFuture(py::module* m);
void InitAndSetAgentInstance(py::module* m);
void InvokeRpc(py::module* m);
void StartWorker(py::module* m);
void StartClient(py::module* m);
void StopWorker(py::module* m);
void GetWorkerInfo(py::module* m);
void GetWorkerInfoByRank(py::module* m);
void GetCurrentWorkerInfo(py::module* m);
void GetAllWorkerInfos(py::module* m);
} // namespace pybind
} // namespace paddle
@@ -68,6 +68,8 @@ from . import cloud_utils # noqa: F401
from .sharding import * # noqa: F401
from . import rpc
__all__ = [ # noqa
"spawn", "launch", "scatter", "broadcast", "ParallelEnv", "new_group",
"init_parallel_env", "gloo_init_parallel_env", "gloo_barrier",
@@ -76,5 +78,5 @@ __all__ = [ # noqa
"all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall",
"send", "reduce", "recv", "ReduceOp", "wait", "get_rank",
"ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv",
"reduce_scatter"
"reduce_scatter", "rpc"
]
@@ -18,12 +18,14 @@ from .collective import CollectiveController
from .collective import CollectiveElasticController
from .ps import PSController
from .ipu_controller import IPUController
from .rpc import RpcController
# the order is extremely important
_controllers = [
IPUController,
CollectiveElasticController,
PSController,
RpcController,
CollectiveController,
]
......
@@ -28,6 +28,7 @@ class ControleMode:
COLLECTIVE = "collective"
PS = "ps"
IPU = "ipu"
RPC = "rpc"
class ControllerBase(object):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller, ControleMode
import json
class RpcController(Controller):
@classmethod
def enable(cls, ctx):
if ctx.args.run_mode == ControleMode.RPC:
ctx.logger.debug("{} enabled".format(cls.__name__))
return True
else:
return False
def build_pod(self):
assert (self.ctx.args.master
is not None), "Master is None, Please set master address!"
self._build_pod_with_master()
def _build_pod_with_master(self):
# nproc_per_node
self.pod.replicas = self.pod_replicas()
# rank will be reset when restart
self.pod.rank = int(self.ctx.args.rank)
port = self.ctx.node.get_free_port()
# compatible
endpoints = [
"{}:{}".format(self.ctx.node.ip, p)
for p in self.ctx.node.get_free_ports(self.pod.replicas)
]
data = json.dumps({
"name": self.pod.name,
"rank": self.pod.rank,
"replicas": self.pod.replicas,
"dtype": self.ctx.node.device.dtype,
"candidate": "{}:{}".format(self.ctx.node.ip, port),
"endpoints": ",".join(endpoints),
})
peer_list, rank = self.master.sync_peers(
"/{}/info".format(self.job.id),
self.pod.name,
data,
self.job.replicas,
self.pod.rank,
)
self.pod.rank = rank
if len(peer_list) < 1:
return False
peer_list = [json.loads(i) for i in peer_list]
self.ctx.logger.debug("sync peers done {}".format(peer_list))
self.save_pod_log(peer_list)
global_size = sum([i["replicas"] for i in peer_list])
rank_offset = sum([i["replicas"] for i in peer_list[:rank]])
rpc_master = peer_list[0]["candidate"]
self.pod.reset()
for i in range(self.pod.replicas):
e = {
"PADDLE_MASTER_ENDPOINT": rpc_master,
"PADDLE_WORKER_ENDPOINT": endpoints[i],
"PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
"PADDLE_TRAINERS_NUM": "{}".format(global_size),
}
log_file = f"workerlog.{i + rank_offset}"
self.add_container(envs=e, log_file=log_file)
return True
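For reference, the controller above gives every local process the shared master endpoint plus its own worker endpoint and global rank. A rough sketch of the environment one process might receive on the second of two 2-process nodes (all addresses and ports below are hypothetical, chosen only to illustrate the fields set in `_build_pod_with_master`):

# Hypothetical values; real ports come from Node.get_free_port()/get_free_ports().
env = {
    "PADDLE_MASTER_ENDPOINT": "192.168.0.16:52000",  # candidate endpoint of peer_list[0]
    "PADDLE_WORKER_ENDPOINT": "192.168.0.17:52017",  # this process's own endpoint
    "PADDLE_TRAINER_ID": "3",                        # i + rank_offset = 1 + 2
    "PADDLE_TRAINERS_NUM": "4",                      # global_size = 2 nodes * 2 replicas
}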
@@ -48,7 +48,7 @@ def launch():
- ``--log_dir``: The path for each process's log. e.g., ``--log_dir=output_dir``. Default ``--log_dir=log``.
- ``--run_mode``: The run mode of job, can be: collective/ps/ps-heter/rpc. e.g., ``--run_mode=ps``. Default ``--run_mode=collective``.
- ``--job_id``: The job unique id, it affects the log files' name. e.g., ``--job_id=job1``. Default ``--job_id=default``.
@@ -260,6 +260,27 @@ def launch():
# Please Check the `IPU Parameters` for details
python -m paddle.distributed.launch --devices 4 ipu --hosts=localhost --nproc_per_host=2 --ipus_per_replica=1 --ipu_partition=pod16 --vipu_server=127.0.0.1 train.py
Examples 11 (rpc, cpu, single node):
.. code-block:: bash
:name: code-block-example-bash11
# Training on single node with two local servers
python -m paddle.distributed.launch --master 127.0.0.1:8765 --nnodes 1 --nproc_per_node 2 --rank 0 --run_mode rpc train.py
Examples 12 (rpc, cpu, multi node):
.. code-block:: bash
:name: code-block-example-bash12
# For training on multiple nodes, e.g., 192.168.0.16 and 192.168.0.17, where each node runs 2 servers.
# On 192.168.0.16
python -m paddle.distributed.launch --master 192.168.0.16:8765 --nnodes 2 --nproc_per_node 2 --rank 0 --run_mode rpc train.py
# On 192.168.0.17
python -m paddle.distributed.launch --master 192.168.0.16:8765 --nnodes 2 --nproc_per_node 2 --rank 1 --run_mode rpc train.py
"""
# initialize the context to run
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.rpc.rpc import (
init_rpc,
shutdown,
rpc_async,
rpc_sync,
get_worker_info,
get_all_worker_infos,
get_current_worker_info,
)
__all__ = [
"init_rpc",
"shutdown",
"rpc_async",
"rpc_sync",
"get_worker_info",
"get_all_worker_infos",
"get_current_worker_info",
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import pickle
PythonFunc = namedtuple("PythonFunc", ["func", "args", "kwargs"])
"""Some Python code interfaces called in C++"""
def _serialize(obj):
return pickle.dumps(obj)
def _deserialize(obj):
return pickle.loads(obj)
def _run_py_func(python_func):
result = python_func.func(*python_func.args, **python_func.kwargs)
return result
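The three helpers above define the payload format that the C++ RPC service relies on: the caller pickles a PythonFunc, the remote worker unpickles and runs it, and the return value is pickled back. A minimal local sketch of that round trip (the `add` function and the variable names are purely illustrative):

from paddle.distributed.rpc.internal import (PythonFunc, _serialize,
                                             _deserialize, _run_py_func)

def add(a, b):
    # illustrative user function, not part of the module
    return a + b

# client side: pack the call into bytes
payload = _serialize(PythonFunc(func=add, args=(2, 3), kwargs={}))
# worker side: unpack, execute, and pack the result
result_bytes = _serialize(_run_py_func(_deserialize(payload)))
assert _deserialize(result_bytes) == 5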
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import namedtuple
import pickle
import time
import datetime
import paddle.fluid.core as core
from paddle.distributed.utils.launch_utils import logger
from paddle.distributed.rpc.internal import _serialize, PythonFunc
from paddle.distributed.launch.context import Node
WorkerInfo = namedtuple("WorkerInfo", ["name", "rank", "ip", "port"])
_DEFAULT_RPC_TIMEOUT = -1
_MAX_RPC_TIMEOUT_MS = 0x7fffffff
_BARRIER_TIMEOUT_MAX_DAYS = 99999999
# tcp store for `_barrier_never_timeout`
_barrier_store = None
# count how many times `_barrier_never_timeout` has been called and
# ensure that the barrier key is unique
_barrier_count = 0
def _set_barrier_store(store):
global _barrier_store
_barrier_store = store
def _del_barrier_store():
global _barrier_store
del _barrier_store
def _set_self_info(name, rank, ip, port):
self_info = pickle.dumps(WorkerInfo(name, rank, ip, port))
_barrier_store.set(str(rank), self_info)
def _exchange_all_service_infos(world_size):
all_infos = []
s = set()
for rank in range(world_size):
info = pickle.loads(_barrier_store.get(str(rank)))
assert (info.name not in s
), "The worker name must be unique, but name `{}` is repeated.".format(info.name)
s.add(info.name)
all_infos.append(info)
return all_infos
def _gen_endpoint():
node = Node()
ip = node.get_host_ip()
free_port = node.get_free_port()
return "{}:{}".format(ip, free_port)
def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
"""
Initialize RPC.
Args:
name (str): worker name.
rank (int, optional): worker id, default is None.
world_size (int, optional): number of workers, default is None.
master_endpoint (str, optional): endpoint of the master in the form ``ip:port``; other nodes communicate with the master to
get the information of all worker nodes, default is None.
Returns:
None.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8001")
rpc.shutdown()
"""
rank = int(os.environ["PADDLE_TRAINER_ID"]) if rank is None else rank
world_size = int(
os.environ["PADDLE_TRAINERS_NUM"]) if world_size is None else world_size
worker_endpoint = os.getenv("PADDLE_WORKER_ENDPOINT", None)
if worker_endpoint is None:
worker_endpoint = _gen_endpoint()
logger.info("Trainer {}: worker endpoint: {}".format(rank, worker_endpoint))
master_endpoint = (master_endpoint if master_endpoint is not None else
os.environ["PADDLE_MASTER_ENDPOINT"])
master_addr, master_port = master_endpoint.split(":")
master_port = int(master_port)
stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
store = core.TCPStore(master_addr,
master_port,
rank == 0,
world_size,
timeout=stop_check_timeout)
_set_barrier_store(store)
ip, port = worker_endpoint.split(":")
port = int(port)
_set_self_info(name, rank, ip, port)
all_infos = _exchange_all_service_infos(world_size)
c_infos = []
for node_info in all_infos:
info = core.WorkerInfo(node_info.name, node_info.rank, node_info.ip,
node_info.port)
c_infos.append(info)
core.init_and_set_agent_instance(name, c_infos)
core.rpc_start_worker()
# ensure that all the workers are started
_barrier_never_timeout(rank, world_size)
core.rpc_start_client()
logger.info("Trainer {}: Init RPC done!".format(rank))
def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
"""
Make a blocking RPC call to run function ``fn`` on worker ``to``.
Args:
to (str): name of the destination worker.
fn (callable): a callable function, such as a Python callable.
args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
invocation, default is None.
timeout (int, optional): timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value less than or equal to 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. The default value is -1.
Returns:
Returns the result of running ``fn`` with ``args`` and ``kwargs``.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
def add(a, b):
return a + b
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8002")
ret = rpc.rpc_sync("worker0", add, args=(2, 3))
rpc.shutdown()
"""
fut = _invoke_rpc(to, fn, args, kwargs, timeout)
return fut.wait()
def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
"""
Make a non-blocking RPC call to run function ``fn`` on worker ``to``.
Args:
to (str): name of the destination worker.
fn (callable): a callable function, such as a Python callable.
args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
invocation, default is None.
timeout (int, optional): timeout in seconds to use for this RPC. If
the RPC does not complete in this amount of
time, an exception indicating it has
timed out will be raised. A value less than or equal to 0
indicates an infinite timeout, i.e. a timeout
error will never be raised. The default value is -1.
Returns:
Returns a :class:`FutureWrapper` object that can be waited
on. When completed, the return value of ``fn`` on ``args`` and
``kwargs`` can be obtained by calling ``fut.wait()``.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
def add(a, b):
return a + b
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8003")
fut = rpc.rpc_async("worker0", add, args=(2, 3))
print(fut.wait())
rpc.shutdown()
"""
return _invoke_rpc(to, fn, args, kwargs, timeout)
def _invoke_rpc(to, fn, args, kwargs, timeout):
args = args if args else ()
kwargs = kwargs if kwargs else {}
serial_obj = _serialize(PythonFunc(fn, args, kwargs))
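# the user-facing timeout is given in seconds, while the C++ agent expects
# milliseconds; a value <= 0 means "never time out" and is mapped to the
# largest int32 value below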
timeout_ms = timeout * 1000
timeout_ms = _MAX_RPC_TIMEOUT_MS if timeout_ms <= 0 else timeout_ms
future = core.invoke_rpc(to, serial_obj, timeout_ms)
return future
def _barrier_never_timeout(global_rank, global_world_size):
# max timeout
timeout = datetime.timedelta(days=_BARRIER_TIMEOUT_MAX_DAYS)
if global_world_size < 2:
return
global _barrier_count
barrier_prefix = "Barrier/" + str(_barrier_count) + "/"
_barrier_count += 1
is_master = (global_rank == 0)
def _check_keys_ready(wait_keys):
start_time = time.time()
while len(wait_keys) > 0:
time.sleep(0.1)
elapse_time = time.time() - start_time
if datetime.timedelta(seconds=elapse_time) > timeout:
raise RuntimeError(
"Keys {} are not ready sinck rank {} is waiting them.".
format(wait_keys, global_rank))
wait_keys = list(
filter(lambda key: int(_barrier_store.get(key)) != 1,
wait_keys))
if is_master:
# the master adds its key, waits for all workers' exit keys, and exits last.
# Note: the master must exit last so that the TcpServer is destroyed last.
wait_keys = [
barrier_prefix + str(rank) for rank in range(1, global_world_size)
]
_barrier_store.add(barrier_prefix + str(0), 1)
_check_keys_ready(wait_keys)
else:
wait_keys = [barrier_prefix + str(0)]
_check_keys_ready(wait_keys)
_barrier_store.add(barrier_prefix + str(global_rank), 1)
def shutdown():
"""
Perform a shutdown of the RPC agent, stop the worker and destroy the agent.
This will block until all local and remote RPC processes reach this method
and wait for all outstanding work to complete.
Returns:
None.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8004")
rpc.shutdown()
"""
info = get_current_worker_info()
rank = info.rank
world_size = len(get_all_worker_infos())
# master will exit in the end
_barrier_never_timeout(rank, world_size)
core.rpc_stop_worker()
_del_barrier_store()
logger.info("Trainer {}: rpc shutdown!".format(rank))
def get_worker_info(name):
"""
Get worker information by worker name.
Args:
name (str): name of the worker.
Returns:
A `WorkerInfo` object with attributes `name`, `rank`, `ip` and `port`.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
import os
os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9002"
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8005")
print(rpc.get_worker_info("worker0"))
# {name: worker0, rank: 0, ip: 127.0.0.1, port: 9002}
rpc.shutdown()
"""
return core.rpc_get_worker_info(name)
def get_all_worker_infos():
"""
Get the information of all workers.
Returns:
List[WorkerInfo].
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
import os
os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9003"
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8006")
print(rpc.get_all_worker_infos())
# [{name: worker0, rank: 0, ip: 127.0.0.1, port: 9003}]
rpc.shutdown()
"""
return core.rpc_get_all_worker_infos()
def get_current_worker_info():
"""
Get current worker information.
Returns:
A `WorkerInfo` object with attributes `name`, `rank`, `ip` and `port`.
Examples:
.. code-block:: python
import paddle.distributed.rpc as rpc
import os
os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9004"
rpc.init_rpc("worker0", rank=0, world_size=1,
master_endpoint="127.0.0.1:8007")
print(rpc.get_current_worker_info())
# {name: worker0, rank: 0, ip: 127.0.0.1, port: 9004}
rpc.shutdown()
"""
return core.rpc_get_current_worker_info()
@@ -609,6 +609,7 @@ if(WITH_DISTRIBUTE)
add_subdirectory(ps)
add_subdirectory(auto_parallel)
add_subdirectory(collective)
add_subdirectory(rpc)
# FIXME(typhoonzero): add these tests back
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer")
......
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
list(APPEND TEST_OPS ${TEST_OP})
set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 50)
endforeach()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.distributed as dist
paddle.device.set_device("cpu")
def add(a, b):
a = paddle.to_tensor(a, dtype="float32")
b = paddle.to_tensor(b, dtype="float32")
res = paddle.add(a, b).numpy()
return res
def rpc_add(to, args):
res = dist.rpc.rpc_sync(to, add, args=args)
return res
def worker_name(rank):
return "worker{}".format(rank)
def main():
rank = dist.get_rank()
world_size = dist.get_world_size()
dist.rpc.init_rpc(worker_name(rank))
if rank == 0:
mmap_data1 = np.memmap(
"rpc_launch_data1.npy",
dtype=np.float32,
mode="r",
shape=(10 * world_size, 100),
)
mmap_data2 = np.memmap(
"rpc_launch_data2.npy",
dtype=np.float32,
mode="r",
shape=(10 * world_size, 100),
)
mmap_out = np.memmap(
"rpc_launch_result.npy",
dtype=np.float32,
mode="w+",
shape=(10 * world_size, 100),
)
for i in range(world_size):
a = mmap_data1[i * 10:(i + 1) * 10, :]
b = mmap_data2[i * 10:(i + 1) * 10, :]
args = (a, b)
out = rpc_add(worker_name(i), args)
mmap_out[i * 10:(i + 1) * 10, :] = out[:]
dist.rpc.shutdown()
if __name__ == "__main__":
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import paddle.distributed as dist
import numpy as np
from test_rpc_base import RpcTestBase, RpcLaunchTestBase
paddle.device.set_device("cpu")
def worker_name(rank):
return "worker{}".format(rank)
def paddle_add(a, b):
a = paddle.to_tensor(a)
b = paddle.to_tensor(b)
res = paddle.add(a, b).numpy()
return res
class TestMultiProcessRpc(RpcTestBase):
def test_one_server_sync_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
queues = self.run_rpc(True, 1, paddle_add, args)
out = queues[0].get()
np.testing.assert_allclose(out, res, rtol=1e-05)
def test_one_server_async_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
queues = self.run_rpc(False, 1, paddle_add, args)
out = queues[0].get()
np.testing.assert_allclose(out, res, rtol=1e-05)
def test_two_server_sync_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
queues = self.run_rpc(True, 2, paddle_add, args)
out1 = queues[0].get()
out2 = queues[1].get()
np.testing.assert_allclose(out1, res, rtol=1e-05)
np.testing.assert_allclose(out2, res, rtol=1e-05)
def test_two_server_async_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
queues = self.run_rpc(False, 2, paddle_add, args)
out1 = queues[0].get()
out2 = queues[1].get()
np.testing.assert_allclose(out1, res, rtol=1e-05)
np.testing.assert_allclose(out2, res, rtol=1e-05)
class TestSingleProcessRpc(RpcTestBase):
def setUp(self):
self._port_set = set()
master_endpoint = "127.0.0.1:{}".format(self._find_free_port())
dist.rpc.init_rpc(worker_name(0), 0, 1, master_endpoint)
print("Single Process RPC setUp...")
def tearDown(self):
dist.rpc.shutdown()
print("Single Process RPC tearDown...")
def test_sync_rpc_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
out = dist.rpc.rpc_sync(worker_name(0), paddle_add, args=args)
np.testing.assert_allclose(out, res, rtol=1e-05)
def test_async_rpc_paddle_add(self):
a = np.random.random((10, 100))
b = np.random.random((10, 100))
res = np.add(a, b)
args = (a, b)
out = dist.rpc.rpc_async(worker_name(0), paddle_add, args=args).wait()
np.testing.assert_allclose(out, res, rtol=1e-05)
def test_get_worker_info(self):
info = dist.rpc.get_worker_info(worker_name(0))
self.assertEqual(info.name, worker_name(0))
self.assertEqual(info.rank, 0)
def test_get_all_worker_infos(self):
infos = dist.rpc.get_all_worker_infos()
info = infos[0]
self.assertEqual(info.name, worker_name(0))
self.assertEqual(info.rank, 0)
def test_get_current_worker_info(self):
info = dist.rpc.get_current_worker_info()
self.assertEqual(info.name, worker_name(0))
self.assertEqual(info.rank, 0)
class RpcLaunchTest(RpcLaunchTestBase):
def test_sync_rpc_paddle_add1(self):
nnodes = 2
nproc_per_node = 1
pwd, _ = os.path.split(os.path.realpath(__file__))
model_file = os.path.join(pwd, "rpc_launch_sync_add.py")
a, b = self.create_data(nnodes, nproc_per_node)
res = np.add(a, b)
out = self.launch_rpc(nnodes, nproc_per_node, model_file)
np.testing.assert_allclose(out, res, rtol=1e-05)
def test_sync_rpc_paddle_add2(self):
nnodes = 2
nproc_per_node = 2
pwd, _ = os.path.split(os.path.realpath(__file__))
model_file = os.path.join(pwd, "rpc_launch_sync_add.py")
a, b = self.create_data(nnodes, nproc_per_node)
res = np.add(a, b)
out = self.launch_rpc(nnodes, nproc_per_node, model_file)
np.testing.assert_allclose(out, res, rtol=1e-05)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from multiprocessing import Process, Queue
import subprocess
import socket
from contextlib import closing
import paddle.distributed as dist
import numpy as np
def worker_name(rank):
return "worker{}".format(rank)
def run_rpc_sync(
rank,
world_size,
master_endpoint,
queue,
fn,
args=None,
kwargs=None,
):
dist.rpc.init_rpc(
worker_name(rank),
rank,
world_size,
master_endpoint,
)
res = dist.rpc.rpc_sync(worker_name(0), fn, args=args, kwargs=kwargs)
queue.put(res)
dist.rpc.shutdown()
def run_rpc_sync_master_working(
rank,
world_size,
master_endpoint,
queue,
fn,
args=None,
kwargs=None,
):
dist.rpc.init_rpc(
worker_name(rank),
rank,
world_size,
master_endpoint,
)
if dist.get_rank() == 0:
for i in range(1, world_size):
res = dist.rpc.rpc_sync(worker_name(i),
fn,
args=args,
kwargs=kwargs)
queue.put(res)
dist.rpc.shutdown()
def run_rpc_async(
rank,
world_size,
master_endpoint,
queue,
fn,
args=None,
kwargs=None,
):
dist.rpc.init_rpc(
worker_name(rank),
rank,
world_size,
master_endpoint,
)
res = dist.rpc.rpc_async(worker_name(0), fn, args=args, kwargs=kwargs)
queue.put(res.wait())
dist.rpc.shutdown()
def run_rpc_async_master_working(
rank,
world_size,
master_endpoint,
queue,
fn,
args=None,
kwargs=None,
):
dist.rpc.init_rpc(
worker_name(rank),
rank,
world_size,
master_endpoint,
)
if dist.get_rank() == 0:
for i in range(1, world_size):
res = dist.rpc.rpc_async(worker_name(i),
fn,
args=args,
kwargs=kwargs)
queue.put(res.wait())
dist.rpc.shutdown()
class RpcTestBase(unittest.TestCase):
def setUp(self):
self._port_set = set()
print("RPC setUp...")
def tearDown(self):
if len(self.processes) != 0:
[p.join() for p in self.processes]
print("RPC tearDown...")
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def run_rpc(self, sync, world_size, fn, fn_args=None, fn_kwargs=None):
self.processes = []
queues = []
master_endpoint = "127.0.0.1:{}".format(self._find_free_port())
for rank in range(world_size):
q = Queue()
queues.append(q)
if sync:
self.processes.append(
Process(
target=run_rpc_sync,
args=(
rank,
world_size,
master_endpoint,
q,
fn,
fn_args,
fn_kwargs,
),
))
else:
self.processes.append(
Process(
target=run_rpc_async,
args=(
rank,
world_size,
master_endpoint,
q,
fn,
fn_args,
fn_kwargs,
),
))
[p.start() for p in self.processes]
return queues
class RpcLaunchTestBase(unittest.TestCase):
def setUp(self):
self._port_set = set()
print("Launch RPC setUp...")
def tearDown(self):
self.remove_data()
print("Launch RPC tearDown...")
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def create_data(self, nnodes, nproc_per_node):
mmap_data1 = np.memmap(
"rpc_launch_data1.npy",
dtype=np.float32,
mode="w+",
shape=(10 * nnodes * nproc_per_node, 100),
)
mmap_data2 = np.memmap(
"rpc_launch_data2.npy",
dtype=np.float32,
mode="w+",
shape=(10 * nnodes * nproc_per_node, 100),
)
for i in range(nnodes * nproc_per_node):
a = np.random.random((10, 100)).astype(np.float32)
b = np.random.random((10, 100)).astype(np.float32)
mmap_data1[i * 10:(i + 1) * 10, :] = a
mmap_data2[i * 10:(i + 1) * 10, :] = b
return mmap_data1, mmap_data2
def remove_data(self):
os.remove("rpc_launch_data1.npy")
os.remove("rpc_launch_data2.npy")
def launch_rpc(self, nnodes, nproc_per_node, model_file):
master_endpoint = "127.0.0.1:{}".format(self._find_free_port())
log_dir = "log"
tr_cmd = "python -m paddle.distributed.launch --master {} --rank {} --nnodes {} --nproc_per_node {} --run_mode rpc {} --log_dir {}"
cmds = [
tr_cmd.format(master_endpoint, rank, nnodes, nproc_per_node,
model_file, log_dir) for rank in range(nnodes)
]
processes = [subprocess.Popen(cmd.strip().split()) for cmd in cmds]
[proc.communicate() for proc in processes]
out = np.memmap(
"rpc_launch_result.npy",
dtype=np.float32,
mode="r",
shape=(10 * nnodes * nproc_per_node, 100),
)
os.remove("rpc_launch_result.npy")
import shutil
shutil.rmtree(log_dir)
return out
@@ -304,6 +304,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers.ascend',
'paddle.distributed.fleet.meta_optimizers.dygraph_optimizer',
'paddle.distributed.fleet.runtime',
'paddle.distributed.rpc',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
'paddle.distributed.fleet.metrics',
......