未验证 提交 28cb0067 编写于 作者: Z zmxdream 提交者: GitHub

[GPUPS]FleetWrapper initialize (#44441)

* fix FleetWrapper initialize
上级 0e2dd2f3
......@@ -37,6 +37,7 @@ namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
std::mutex FleetWrapper::ins_mutex;
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include <atomic>
#include <ctime>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <unordered_map>
......@@ -381,8 +382,11 @@ class FleetWrapper {
void Revert();
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
{
std::lock_guard<std::mutex> lk(ins_mutex);
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
}
return s_instance_;
}
......@@ -397,6 +401,7 @@ class FleetWrapper {
private:
static std::shared_ptr<FleetWrapper> s_instance_;
static std::mutex ins_mutex;
#ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
......
......@@ -41,8 +41,9 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet")
.def(py::init())
py::class_<framework::FleetWrapper, std::shared_ptr<framework::FleetWrapper>>(
*m, "Fleet")
.def(py::init([]() { return framework::FleetWrapper::GetInstance(); }))
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
......
......@@ -88,7 +88,8 @@ class FleetTranspiler(Fleet):
if role_maker is None:
role_maker = MPISymetricRoleMaker()
super(FleetTranspiler, self).init(role_maker)
self._fleet_ptr = core.Fleet()
if self._fleet_ptr is None:
self._fleet_ptr = core.Fleet()
def _init_transpiler_worker(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册