diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index cfb23d1be2acfed0a878cb3bffa241afa2cf3de8..81b2b0a12b2c37b7e9f36aa2df3e8bc5013aacd1 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -209,6 +209,15 @@ class PSGPUWrapper { void EndPass() { HeterPs_->end_pass(); } void ShowOneTable(int index) { HeterPs_->show_one_table(index); } + void Finalize() { + VLOG(3) << "PSGPUWrapper Begin Finalize."; + if (s_instance_ == nullptr) { + return; + } + s_instance_ = nullptr; + VLOG(3) << "PSGPUWrapper Finalize Finished."; + } + private: static std::shared_ptr s_instance_; Dataset* dataset_; diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index 0c239f8157e5dff03ba71bb018c77b7b5a4b86a6..bdd7abe1d8332a702bef3de2c4948c1f94f7a85f 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -48,6 +48,8 @@ void BindPSGPUWrapper(py::module* m) { .def("end_pass", &framework::PSGPUWrapper::EndPass, py::call_guard()) .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS, + py::call_guard()) + .def("finalize", &framework::PSGPUWrapper::Finalize, py::call_guard()); } // end PSGPUWrapper #endif