未验证 提交 3858f458 编写于 作者: S ShenLiang 提交者: GitHub

rm Singleton of reducer (#30775)

上级 2c974cc3
...@@ -41,8 +41,6 @@ namespace paddle { ...@@ -41,8 +41,6 @@ namespace paddle {
namespace imperative { namespace imperative {
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
std::shared_ptr<Reducer> Reducer::s_instance_ = NULL;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
static void ConcatTensorsForAllReduce( static void ConcatTensorsForAllReduce(
const DeviceContext &context, const DeviceContext &context,
...@@ -225,14 +223,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars, ...@@ -225,14 +223,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
}))); })));
var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index; var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
} }
std::call_once(once_flag_, []() {
std::atexit([]() { Reducer::GetInstance()->ReleaseReducer(); });
});
} }
void Reducer::ReleaseReducer() { parallel_ctx_.reset(); }
void Reducer::InitializeDenseGroups( void Reducer::InitializeDenseGroups(
const std::vector<size_t> &variable_indices_, Group *p_group) { const std::vector<size_t> &variable_indices_, Group *p_group) {
int64_t all_length = 0; int64_t all_length = 0;
......
...@@ -108,44 +108,16 @@ class Reducer { ...@@ -108,44 +108,16 @@ class Reducer {
void AddDistHook(size_t var_index); void AddDistHook(size_t var_index);
// void MarkDenseVarReady(size_t var_index);
// void MarkSparseVarReady(size_t var_index);
void MarkVarReady(const size_t var_index, const bool is_used_var); void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(size_t group_index); void MarkGroupReady(size_t group_index);
void FinalizeBackward(); void FinalizeBackward();
void ReleaseReducer();
std::vector<std::vector<size_t>> RebuildGruops(); std::vector<std::vector<size_t>> RebuildGruops();
inline bool NeedRebuildGroup() { return !has_rebuilt_group_; } inline bool NeedRebuildGroup() { return !has_rebuilt_group_; }
// Reducer Singleton
static std::shared_ptr<Reducer> SetInstance(
const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
const std::vector<std::vector<size_t>>& group_indices,
const std::vector<bool>& is_sparse_gradient,
std::shared_ptr<imperative::ParallelContext> parallel_ctx,
const std::vector<size_t>& group_size_limits, bool find_unused_vars) {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::imperative::Reducer(
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits, find_unused_vars));
}
return s_instance_;
}
static std::shared_ptr<Reducer> GetInstance() {
PADDLE_ENFORCE_EQ(
s_instance_ != NULL, true,
platform::errors::InvalidArgument("Reducer is not initialized."));
return s_instance_;
}
private: private:
std::vector<std::shared_ptr<imperative::VarBase>> vars_; std::vector<std::shared_ptr<imperative::VarBase>> vars_;
std::vector<std::vector<size_t>> group_indices_; std::vector<std::vector<size_t>> group_indices_;
......
...@@ -1390,16 +1390,11 @@ void BindImperative(py::module *m_ptr) { ...@@ -1390,16 +1390,11 @@ void BindImperative(py::module *m_ptr) {
py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>( py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>(
m, "Reducer", R"DOC()DOC") m, "Reducer", R"DOC()DOC")
.def(py::init([]( .def(py::init<const std::vector<std::shared_ptr<imperative::VarBase>> &,
const std::vector<std::shared_ptr<imperative::VarBase>> &vars, const std::vector<std::vector<size_t>> &,
const std::vector<std::vector<size_t>> &group_indices, const std::vector<bool> &,
const std::vector<bool> &is_sparse_gradient, std::shared_ptr<imperative::ParallelContext>,
std::shared_ptr<imperative::ParallelContext> parallel_ctx, const std::vector<size_t> &, bool>())
const std::vector<size_t> &group_size_limits, bool find_unused_vars) {
return imperative::Reducer::SetInstance(
vars, group_indices, is_sparse_gradient, parallel_ctx,
group_size_limits, find_unused_vars);
}))
.def("prepare_for_backward", &imperative::Reducer::PrepareForBackward, .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward,
py::arg("vars"), py::call_guard<py::gil_scoped_release>()); py::arg("vars"), py::call_guard<py::gil_scoped_release>());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册