diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc
index 2722edf8deac66772ff44f023359cd6375b454cb..4963fe0453ac2d35c1f4254e279360b130e9770d 100644
--- a/paddle/fluid/distributed/collective/process_group.cc
+++ b/paddle/fluid/distributed/collective/process_group.cc
@@ -36,11 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
   return instance;
 }
 
-void ProcessGroupIdMap::DestroyProcessGroup(int gid) {
-  int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count();
-  for (int i = 0; i < use_count; ++i) {
-    ProcessGroupIdMap::GetInstance()[gid].reset();
-  }
+// Empties the global id map, releasing the map's own reference to every
+// registered process group. reset() only gives up the reference held by the
+// pointer it is called on, so calling it use_count() times released nothing
+// extra; a group that is still referenced elsewhere stays alive.
+void ProcessGroupIdMap::DestroyProcessGroup() {
+  ProcessGroupIdMap::GetInstance().clear();
 }
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h
index de67eaf2a5e875a598cb68ce004874e697542e4d..df9cfdfbc794f0f290350ee27e5eaac0d9b1c571 100644
--- a/paddle/fluid/distributed/collective/process_group.h
+++ b/paddle/fluid/distributed/collective/process_group.h
@@ -502,7 +502,8 @@ class ProcessGroupIdMap
     : public std::unordered_map> {
  public:
   static ProcessGroupIdMap& GetInstance();
-  static void DestroyProcessGroup(int gid);
+  // Removes every entry from the map returned by GetInstance().
+  static void DestroyProcessGroup();
 };
 
 // TODO(dev): The following method will be removed soon.
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index a30909322eccfe9333ed3ab308d1b136ce623b58..a81e221126214da6e1c80e8931f2a45de7f0e931 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -215,7 +215,7 @@ endif()
 copy_if_different(${pybind_file} ${pybind_file_final})
 
 if (WITH_CUSTOM_DEVICE)
-cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator)
+cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator phi_api)
 endif()
 
 if(NOT "${OP_LIST}" STREQUAL "")
diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc
index e59b308d93572d857e7015270490a87937f798a5..c8f829a2f017f46cf06d5e12770d0d4c2816182e 100644
--- a/paddle/fluid/operators/custom_device_common_op_registry.cc
+++ b/paddle/fluid/operators/custom_device_common_op_registry.cc
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/custom_device_common_op_registry.h"
+#include "paddle/fluid/distributed/collective/process_group.h"
+#include "paddle/fluid/operators/collective/c_concat_op.h"
 #include "paddle/fluid/operators/load_combine_op.h"
 #include "paddle/fluid/operators/run_program_op.h"
 #include "paddle/fluid/operators/save_combine_op.h"
+#include "paddle/phi/api/backward/backward_api.h"
+#include "paddle/phi/api/include/api.h"
 #include "paddle/phi/backends/device_manager.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
 
 #define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \
   static paddle::framework::OpKernelRegistrar \
@@ -43,6 +48,458 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+// c_concat for CustomDevice: all-gathers X through the ring's ProcessGroup and
+// concatenates the nranks gathered pieces along X's last axis into Out.
+template
+class CConcatOpCustomDeviceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input("X");
+    auto out = ctx.Output("Out");
+
+    int nranks = ctx.Attr("nranks");
+    int rank = ctx.Attr("rank");
+    int rid = ctx.Attr("ring_id");
+    PADDLE_ENFORCE_GE(rank,
+                      0,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_concat must be "
+                          "greater than or equal to 0.",
+                          rank));
+    PADDLE_ENFORCE_GE(nranks,
+                      2,
+                      platform::errors::PreconditionNotMet(
+                          "The value of nranks (%d) for c_concat must be "
+                          "greater than or equal to 2.",
+                          nranks));
+    PADDLE_ENFORCE_LT(rank,
+                      nranks,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_concat must be "
+                          "less than that of nranks (%d).",
+                          rank,
+                          nranks));
+
+    auto& dev_ctx = ctx.template device_context();
+    phi::DenseTensor temp_out;
+    framework::DDim temp_out_dims = x->dims();
+    temp_out_dims[0] *= nranks;
+    temp_out.Resize(temp_out_dims);
+    dev_ctx.template Alloc(&temp_out);
+
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      distributed::ProcessGroup* pg = map->get(rid);
+      std::vector in_tensor;
+      std::vector out_tensor;
+      in_tensor.push_back(*x);
+      out_tensor.push_back(temp_out);
+      auto task = pg->AllGather(in_tensor, out_tensor);
+      task->Wait();
+    } else {
+      PADDLE_THROW(phi::errors::Unavailable(
+          "CustomDevice c_concat only support ProcessGroup"));
+    }
+    std::vector inputs;
+    int axis = x->dims().size() - 1;
+    auto out_dims = x->dims();
+    out_dims[out_dims.size() - 1] *= nranks;
+    int rows_per_tensor = x->dims()[0];
+    int offset = 0;
+    for (int i = 0; i < nranks; i++) {
+      phi::DenseTensor temp = temp_out.Slice(offset, offset + rows_per_tensor);
+      inputs.emplace_back(temp);
+      offset += rows_per_tensor;
+    }
+
+    out->Resize(out_dims);
+    std::vector inputs_t(inputs.size());
+    for (size_t i = 0; i < inputs.size(); i++) {
+      auto t = std::make_shared();
+      t->ShareDataWith(inputs[i]);
+      inputs_t[i].set_impl(t);
+    }
+    auto output = paddle::experimental::concat(inputs_t, axis);
+    out->ShareDataWith(
+        *reinterpret_cast(output.impl().get()));
+  }
+};
+
+// c_split for CustomDevice: splits X into nranks equal parts along its last
+// axis and shares the part at index `rank` with Out.
+template
+class CSplitOpCustomDeviceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input("X");
+    auto out = ctx.Output("Out");
+
+    int nranks = ctx.Attr("nranks");
+    int rank = ctx.Attr("rank");
+
+    PADDLE_ENFORCE_GE(rank,
+                      0,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_split must be "
+                          "greater than or equal to 0.",
+                          rank));
+    PADDLE_ENFORCE_GE(nranks,
+                      2,
+                      platform::errors::PreconditionNotMet(
+                          "The value of nranks (%d) for c_split must be "
+                          "greater than or equal to 2.",
+                          nranks));
+    PADDLE_ENFORCE_LT(rank,
+                      nranks,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_split must be "
+                          "less than that of nranks (%d).",
+                          rank,
+                          nranks));
+
+    auto dims = x->dims();
+    auto dims_size = dims.size();
+
+    dims[dims_size - 1] /= nranks;
+    out->Resize(dims);
+    std::vector split_list(nranks, dims[dims_size - 1]);
+    int axis = dims_size - 1;
+
+    auto x_tmp = std::make_shared();
+    x_tmp->ShareDataWith(*x);
+    paddle::Tensor x_tensor(x_tmp);
+    auto outputs = paddle::experimental::split(x_tensor, split_list, axis);
+    out->ShareDataWith(
+        *reinterpret_cast(outputs[rank].impl().get()));
+  }
+};
+
+// c_embedding for CustomDevice: looks up rows of the local table shard W for
+// ids in [start_index, start_index + N), N = W.dims()[0]; the rows produced
+// for ids outside that range are multiplied by a zero mask.
+template
+class CEmbeddingOpCustomDeviceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* ids_t = ctx.Input("Ids");
+    auto* table_t = ctx.Input("W");
+    auto* output_t = ctx.Output("Out");
+    auto out_dims = output_t->dims();
+    auto start_index = ctx.Attr("start_index");
+
+    auto K = ids_t->numel();
+    auto N = table_t->dims()[0];
+    auto D = table_t->dims()[1];
+    auto index_type = ids_t->dtype();
+    if (index_type == phi::DataType::INT32 ||
+        index_type == phi::DataType::INT64) {
+      auto x_tmp = std::make_shared();
+      x_tmp->ShareDataWith(*ids_t).Resize({K});
+      auto w_tmp = std::make_shared();
+      w_tmp->ShareDataWith(*table_t).Resize({N, D});
+      paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp);
+      auto start_index_tensor = paddle::experimental::full_like(
+          x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
+      auto end_index_tensor = paddle::experimental::full_like(
+          x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
+      auto ids_mask_tensor = paddle::experimental::logical_and(
+          x_tensor.greater_equal(start_index_tensor),
+          x_tensor.less_than(end_index_tensor));
+      auto ids_tensor = (x_tensor - start_index_tensor)
+                            .multiply(paddle::experimental::cast(
+                                ids_mask_tensor, x_tensor.dtype()));
+      auto out_tensor =
+          paddle::experimental::reshape(
+              paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()),
+              {K, 1})
+              .multiply(paddle::experimental::reshape(
+                  paddle::experimental::embedding(
+                      ids_tensor, w_tensor, -1, false),
+                  {K, D}));
+      output_t
+          ->ShareDataWith(
+              *reinterpret_cast(out_tensor.impl().get()))
+          .Resize(out_dims);
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "CustomDevice c_embedding ids only support int32 or int64."));
+    }
+  }
+};
+
+// c_embedding_grad for CustomDevice: computes W@GRAD with embedding_grad from
+// the Out@GRAD rows masked to ids in [start_index, start_index + N).
+template
+class CEmbeddingGradOpCustomDeviceKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto start_index = ctx.Attr("start_index");
+    auto ids_t = ctx.Input("Ids");
+    auto d_output_t =
+        ctx.Input(framework::GradVarName("Out"));
+    auto table_t = ctx.Input("W");
+    auto table_grad_t =
+        ctx.Output(framework::GradVarName("W"));
+    table_grad_t->Resize(table_t->dims());
+    auto& dev_ctx = ctx.template device_context();
+
+    auto K = ids_t->numel();
+    auto N = table_t->dims()[0];
+    auto D = table_t->dims()[1];
+    const auto& index_type = ids_t->dtype();
+    if (index_type == phi::DataType::INT32 ||
+        index_type == phi::DataType::INT64) {
+      auto x_tmp = std::make_shared();
+      x_tmp->ShareDataWith(*ids_t).Resize({K});
+      auto w_tmp = std::make_shared();
+      w_tmp->set_meta(table_t->meta());
+      dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype());
+      auto out_grad_tmp = std::make_shared();
+      out_grad_tmp->ShareDataWith(*d_output_t).Resize({K, D});
+      paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp),
+          out_grad_tensor(out_grad_tmp);
+      auto start_index_tensor = paddle::experimental::full_like(
+          x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
+      auto end_index_tensor = paddle::experimental::full_like(
+          x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
+      // The upper bound is exclusive, matching the forward kernel's mask.
+      auto ids_mask_tensor = paddle::experimental::logical_and(
+          x_tensor.greater_equal(start_index_tensor),
+          x_tensor.less_than(end_index_tensor));
+      auto real_ids_tensor = (x_tensor - start_index_tensor)
+                                 .multiply(paddle::experimental::cast(
+                                     ids_mask_tensor, x_tensor.dtype()));
+      auto out_grad_tensor_mul_mask =
+          paddle::experimental::reshape(out_grad_tensor, {K, D})
+              .multiply(paddle::experimental::reshape(
+                  paddle::experimental::cast(ids_mask_tensor, table_t->dtype()),
+                  {K, 1}));
+      paddle::Tensor table_grad_tensor;
+      paddle::experimental::embedding_grad(real_ids_tensor,
+                                           w_tensor,
+                                           out_grad_tensor_mul_mask,
+                                           -1,
+                                           false,
+                                           &table_grad_tensor);
+      table_grad_t->ShareDataWith(
+          *reinterpret_cast(table_grad_tensor.impl().get()));
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "CustomDevice c_embedding ids only support int32 or int64."));
+    }
+  }
+};
+
+// c_softmax_with_cross_entropy for CustomDevice: each rank holds D columns of
+// the logits; the row max, the target logit and the exp-sum are combined
+// across ranks with three AllReduce calls on the ring's ProcessGroup.
+template
+class CSoftmaxWithCrossEntropyOpCustomDeviceKernel
+    : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const int rid = ctx.Attr("ring_id");
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      const phi::DenseTensor* logits = ctx.Input("Logits");
+      const phi::DenseTensor* labels = ctx.Input("Label");
+      phi::DenseTensor* softmax = ctx.Output("Softmax");
+      phi::DenseTensor* loss = ctx.Output("Loss");
+      auto softmax_dims = softmax->dims();
+      auto loss_dims = loss->dims();
+
+      const int64_t ignore_index = ctx.Attr("ignore_index");
+      PADDLE_ENFORCE_LT(ignore_index,
+                        0,
+                        platform::errors::InvalidArgument(
+                            "When SoftmaxWithCrossEntropy run on CustomDevice, "
+                            "ignore_index should be < 0, however it's %ld",
+                            ignore_index));
+      const int rank = ctx.Attr("rank");
+
+      distributed::ProcessGroup* pg = map->get(rid);
+      distributed::AllreduceOptions opts;
+
+      // allocate memory on device.
+      const auto& logits_dims = logits->dims();
+
+      const int axis = logits_dims.size() - 1;
+      const int N = phi::funcs::SizeToAxis(axis, logits_dims);
+      const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
+
+      auto logits_2d = std::make_shared();
+      auto labels_1d = std::make_shared();
+      logits_2d->ShareDataWith(*logits).Resize({N, D});
+      labels_1d->ShareDataWith(*labels).Resize({N});
+      paddle::Tensor logits_2d_tensor(logits_2d), labels_1d_tensor(labels_1d);
+
+      // step 1, obtain logit_max
+      auto logits_2d_max_tensor = logits_2d_tensor.max({1}, true);
+      std::vector in_out;
+      in_out.push_back(*reinterpret_cast(
+          logits_2d_max_tensor.impl().get()));
+      opts.reduce_op = distributed::ReduceOp::MAX;
+      pg->AllReduce(in_out, in_out, opts)->Synchronize();
+
+      // step 2, obtain logit - logit_max
+      auto logits_2d_sub_max = paddle::experimental::clip(
+          logits_2d_tensor - logits_2d_max_tensor, -64., 0.);
+
+      // step 3, obtain predict target
+      const int start_index = rank * D;
+      auto start_index_tensor =
+          paddle::experimental::full_like(labels_1d_tensor,
+                                          start_index,
+                                          labels_1d_tensor.dtype(),
+                                          labels_1d_tensor.place());
+      auto end_index_tensor =
+          paddle::experimental::full_like(labels_1d_tensor,
+                                          start_index + D,
+                                          labels_1d_tensor.dtype(),
+                                          labels_1d_tensor.place());
+      auto labels_1d_mask = paddle::experimental::logical_and(
+          labels_1d_tensor.greater_equal(start_index_tensor),
+          labels_1d_tensor.less_than(end_index_tensor));
+      auto real_label_tensor =
+          (labels_1d_tensor - start_index_tensor)
+              .multiply(paddle::experimental::cast(labels_1d_mask,
+                                                   labels_1d_tensor.dtype()));
+
+      auto predicted_logits_tensor =
+          logits_2d_sub_max
+              .multiply(paddle::experimental::cast(
+                  paddle::experimental::one_hot(real_label_tensor, D),
+                  logits_2d_sub_max.dtype()))
+              .sum({1}, logits_2d_sub_max.dtype(), false)
+              .multiply(paddle::experimental::cast(labels_1d_mask,
+                                                   logits_2d_sub_max.dtype()));
+
+      in_out.clear();
+      in_out.push_back(*reinterpret_cast(
+          predicted_logits_tensor.impl().get()));
+      opts.reduce_op = distributed::ReduceOp::SUM;
+      pg->AllReduce(in_out, in_out, opts)->Synchronize();
+
+      // step 4, obtain exp(logit)
+      auto softmax_2d_tensor = logits_2d_sub_max.exp();
+
+      // step 5, obtain sum_exp_logits
+      auto sum_exp_logits_tensor =
+          softmax_2d_tensor.sum({1}, softmax_2d_tensor.dtype(), false);
+
+      in_out.clear();
+      in_out.push_back(*reinterpret_cast(
+          sum_exp_logits_tensor.impl().get()));
+      opts.reduce_op = distributed::ReduceOp::SUM;
+      pg->AllReduce(in_out, in_out, opts)->Synchronize();
+
+      auto softmax_out = softmax_2d_tensor.divide(
+          paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
+      auto labels_1d_not_equal_ignore = labels_1d_tensor.not_equal(
+          paddle::experimental::full_like(labels_1d_tensor,
+                                          ignore_index,
+                                          labels_1d_tensor.dtype(),
+                                          labels_1d_tensor.place()));
+      auto loss_out =
+          (sum_exp_logits_tensor.log() - predicted_logits_tensor)
+              .multiply(paddle::experimental::cast(
+                  labels_1d_not_equal_ignore, sum_exp_logits_tensor.dtype()));
+      softmax
+          ->ShareDataWith(
+              *reinterpret_cast(softmax_out.impl().get()))
+          .Resize(softmax_dims);
+      loss->ShareDataWith(
+              *reinterpret_cast(loss_out.impl().get()))
+          .Resize(loss_dims);
+    } else {
+      PADDLE_THROW(
+          phi::errors::Unavailable("CustomDevice c_softmax_with_cross_entropy "
+                                   "only support ProcessGroup"));
+    }
+  }
+};
+
+// c_softmax_with_cross_entropy_grad for CustomDevice: writes
+// (softmax * (label != ignore_index) - one_hot(label - rank * D)) * loss_grad
+// into Logits@GRAD.
+template
+class CSoftmaxWithCrossEntropyGradCustomDeviceKernel
+    : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const phi::DenseTensor* labels = context.Input("Label");
+    const phi::DenseTensor* loss_grad =
+        context.Input(framework::GradVarName("Loss"));
+    const phi::DenseTensor* softmax =
+        context.Input("Softmax");
+    phi::DenseTensor* logit_grad =
+        context.Output(framework::GradVarName("Logits"));
+
+    const int64_t ignore_index = context.Attr("ignore_index");
+    const int rank = context.Attr("rank");
+    if (logit_grad != softmax) {
+      framework::TensorCopy(
+          *softmax, context.GetPlace(), context.device_context(), logit_grad);
+    }
+    const auto softmax_dims = softmax->dims();
+    const int axis = softmax_dims.size() - 1;
+    const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
+    const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
+    const auto& label_type = labels->dtype();
+
+    if (label_type == phi::DataType::INT32 ||
+        label_type == phi::DataType::INT64) {
+      auto logit_grad_t = std::make_shared();
+      logit_grad_t->ShareDataWith(*logit_grad).Resize({N, D});
+      auto loss_grad_t = std::make_shared();
+      loss_grad_t->ShareDataWith(*loss_grad).Resize({N});
+      auto labels_1d = std::make_shared();
+      labels_1d->ShareDataWith(*labels).Resize({N});
+      paddle::Tensor logits_grad_tensor(logit_grad_t),
+          loss_grad_tensor(loss_grad_t), labels_1d_tensor(labels_1d);
+
+      auto labels_1d_not_equal_ignore = paddle::experimental::reshape(
+          paddle::experimental::not_equal(
+              labels_1d_tensor,
+              paddle::experimental::full_like(labels_1d_tensor,
+                                              ignore_index,
+                                              labels_1d_tensor.dtype(),
+                                              labels_1d_tensor.place())),
+          {N, 1});
+      auto start_index_tensor =
+          paddle::experimental::full_like(labels_1d_tensor,
+                                          rank * D,
+                                          labels_1d_tensor.dtype(),
+                                          labels_1d_tensor.place());
+
+      auto logits_grad_out_tensor1 = paddle::experimental::subtract(
+          paddle::experimental::multiply(
+              logits_grad_tensor,
+              paddle::experimental::cast(labels_1d_not_equal_ignore,
+                                         logits_grad_tensor.dtype())),
+          paddle::experimental::cast(
+              paddle::experimental::one_hot(
+                  paddle::experimental::subtract(labels_1d_tensor,
+                                                 start_index_tensor),
+                  D),
+              logits_grad_tensor.dtype()));
+
+      auto logits_grad_out_tensor2 = paddle::experimental::multiply(
+          logits_grad_out_tensor1,
+          paddle::experimental::reshape(loss_grad_tensor, {N, 1}));
+      logit_grad
+          ->ShareDataWith(*reinterpret_cast(
+              logits_grad_out_tensor2.impl().get()))
+          .Resize(softmax_dims);
+    } else {
+      PADDLE_THROW(phi::errors::Unavailable(
+          "CustomDevice c_softmax_with_cross_entropy_grad "
+          "label_type only support int32/int64"));
+    }
+  }
+};
+
 template
 void FeedDenseTensorKernel(const Context& dev_ctx,
                            const phi::ExtendedTensor& x,
@@ -87,6 +544,66 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
           LoadCombineOpKernel,
       paddle::operators::
           LoadCombineOpKernel);
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_concat,
+      device_type,
+      paddle::operators::CConcatOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>,
+      paddle::operators::CConcatOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          paddle::platform::float16>);
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_split,
+      device_type,
+      paddle::operators::CSplitOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>,
+      paddle::operators::CSplitOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          int>,
+      paddle::operators::CSplitOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          paddle::platform::float16>);
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_embedding,
+      device_type,
+      paddle::operators::CEmbeddingOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>);
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_embedding_grad,
+      device_type,
+      paddle::operators::CEmbeddingGradOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>);
+
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_softmax_with_cross_entropy,
+      device_type,
+      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>,
+      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          double>,
+      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          paddle::platform::float16>);
+
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_softmax_with_cross_entropy_grad,
+      device_type,
+      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          float>,
+      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          double>,
+      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          paddle::platform::float16>);
+
 #endif
 }
 
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 46d690c69a0525682129f64c634591a8bdd16c99..01df736fb10b34d54c6cd93842eaacab71787b7a 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -1357,7 +1357,6 @@ void BindDistributed(py::module *m) {
       *m, "ProcessGroupIdMap")
       .def_static("destroy",
                   distributed::ProcessGroupIdMap::DestroyProcessGroup,
-                  py::arg("group_id") = 0,
                   py::call_guard());
 }
 
diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py
index e86b1bc32ec6f2de30832f0ffce992e058377b35..8f6237bfa4c4b1e2df7a87a04d6135e4fb69a59d 100644
--- a/python/paddle/distributed/__init__.py
+++ b/python/paddle/distributed/__init__.py
@@ -32,7 +32,6 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
 from .collective import split # noqa: F401
 from .collective import new_group # noqa: F401
 from .collective import is_available
-from .collective import _destroy_process_group_id_map
 from .communication import (
     stream,
     ReduceOp,
@@ -122,5 +121,3 @@ __all__ = [ # noqa
     "is_available",
     "get_backend",
 ]
-
-atexit.register(_destroy_process_group_id_map)
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 9b0eb8b0895e8e0fcfe366a7ed952a9f4bd0d375..e322b34575d5fdbcc82f27d9a50b0f8dd7bb634c 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -172,16 +172,6 @@ def _set_custom_gid(gid):
     _custom_gid = gid
 
 
-def _destroy_process_group_id_map():
-    """
-
-    Destroy the custom process group. Designed for CustomDevice.
-
-
-    """
-    core.ProcessGroupIdMap.destroy()
-
-
 def new_group(ranks=None, backend=None, timeout=_default_timeout):
     """
 
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 15ea4367e5e42c2874ab9577f63ba45f65492f29..5df4a15efc7d3106104ef4500642a4b726d6ed65 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -223,3 +223,5 @@ atexit.register(core.clear_executor_cache)
 # Keep clear_kernel_factory running before clear_device_manager
 atexit.register(core.clear_device_manager)
 atexit.register(core.clear_kernel_factory)
+# atexit runs handlers in reverse registration order, so this one runs first.
+atexit.register(core.ProcessGroupIdMap.destroy)