未验证 提交 28cb0067 编写于 作者: Z zmxdream 提交者: GitHub

[GPUPS]FleetWrapper initialize (#44441)

* fix FleetWrapper initialize
上级 0e2dd2f3
...@@ -37,6 +37,7 @@ namespace framework { ...@@ -37,6 +37,7 @@ namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL; std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false; bool FleetWrapper::is_initialized_ = false;
std::mutex FleetWrapper::ins_mutex;
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL; std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include <atomic> #include <atomic>
#include <ctime> #include <ctime>
#include <map> #include <map>
#include <mutex>
#include <random> #include <random>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -381,8 +382,11 @@ class FleetWrapper { ...@@ -381,8 +382,11 @@ class FleetWrapper {
void Revert(); void Revert();
// FleetWrapper singleton // FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) { {
s_instance_.reset(new paddle::framework::FleetWrapper()); std::lock_guard<std::mutex> lk(ins_mutex);
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
} }
return s_instance_; return s_instance_;
} }
...@@ -397,6 +401,7 @@ class FleetWrapper { ...@@ -397,6 +401,7 @@ class FleetWrapper {
private: private:
static std::shared_ptr<FleetWrapper> s_instance_; static std::shared_ptr<FleetWrapper> s_instance_;
static std::mutex ins_mutex;
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions; std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif #endif
......
...@@ -41,8 +41,9 @@ namespace py = pybind11; ...@@ -41,8 +41,9 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
void BindFleetWrapper(py::module* m) { void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet") py::class_<framework::FleetWrapper, std::shared_ptr<framework::FleetWrapper>>(
.def(py::init()) *m, "Fleet")
.def(py::init([]() { return framework::FleetWrapper::GetInstance(); }))
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync) .def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync) .def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer) .def("init_server", &framework::FleetWrapper::InitServer)
......
...@@ -88,7 +88,8 @@ class FleetTranspiler(Fleet): ...@@ -88,7 +88,8 @@ class FleetTranspiler(Fleet):
if role_maker is None: if role_maker is None:
role_maker = MPISymetricRoleMaker() role_maker = MPISymetricRoleMaker()
super(FleetTranspiler, self).init(role_maker) super(FleetTranspiler, self).init(role_maker)
self._fleet_ptr = core.Fleet() if self._fleet_ptr is None:
self._fleet_ptr = core.Fleet()
def _init_transpiler_worker(self): def _init_transpiler_worker(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册