From dfcfc8b726410d3cf4c2591392051ee6d4dab244 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Tue, 29 Aug 2023 10:24:21 +0800
Subject: [PATCH] fix memory leak (#56705)

---
 .../distributed/fleet/layers/mpu/mp_ops.py    | 114 +++++++++---------
 1 file changed, 55 insertions(+), 59 deletions(-)

diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index c18ef35e189..26949b12e42 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -14,6 +14,7 @@
 import paddle
 from paddle import _legacy_C_ops
+from paddle.autograd import PyLayer
 from paddle.distributed import collective
 from paddle.fluid import core
 from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
@@ -24,6 +25,59 @@
 from paddle.nn.utils import dygraph_utils
 
 from ....communication.reduce import ReduceOp, _get_reduce_op
 
+
+class c_identity_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group):
+        ctx.group = group
+        return _legacy_C_ops.c_identity(
+            tensor,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            group.id,
+            'use_model_parallel',
+            True,
+        )
+
+    @staticmethod
+    def backward(ctx, dy):
+        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        group = ctx.group
+        group.process_group.all_reduce_on_calc_stream(dy, op_type)
+        return dy
+
+
+class mp_allreduce_eager(PyLayer):
+    @staticmethod
+    def forward(ctx, tensor, group, use_calc_stream, use_model_parallel):
+        ctx.ring_id = group.id
+
+        if use_calc_stream:
+            op_type = _get_reduce_op(ReduceOp.SUM, "_mp_allreduce")
+            group.process_group.all_reduce_on_calc_stream(tensor, op_type)
+            return tensor
+        else:
+            return _legacy_C_ops.c_allreduce_sum_(
+                tensor,
+                'use_calc_stream',
+                use_calc_stream,
+                'ring_id',
+                ctx.ring_id,
+            )
+
+    @staticmethod
+    def backward(ctx, dy):
+        return _legacy_C_ops.c_identity(
+            dy,
+            'use_calc_stream',
+            True,
+            'ring_id',
+            ctx.ring_id,
+            'use_model_parallel',
+            True,
+        )
+
+
 def _c_identity(tensor, group=None):
     """
     Return a copy of the tensor, mainly used with model parallel.
@@ -41,28 +95,7 @@
     ring_id = 0 if group is None else group.id
 
     if in_dygraph_mode():
-        from paddle.autograd import PyLayer
-
-        class c_identity_eager(PyLayer):
-            @staticmethod
-            def forward(ctx, tensor):
-                return _legacy_C_ops.c_identity(
-                    tensor,
-                    'use_calc_stream',
-                    True,
-                    'ring_id',
-                    group.id,
-                    'use_model_parallel',
-                    True,
-                )
-
-            @staticmethod
-            def backward(ctx, dy):
-                op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
-                group.process_group.all_reduce_on_calc_stream(dy, op_type)
-                return dy
-
-        return c_identity_eager.apply(tensor)
+        return c_identity_eager.apply(tensor, group)
     else:
         op_type = 'c_identity'
         helper = LayerHelper(op_type, **locals())
@@ -230,43 +263,6 @@ def _mp_allreduce(
     if in_dygraph_mode():
         group = collective._get_default_group() if group is None else group
         assert op == ReduceOp.SUM, f"Unknown parameter: {op}."
-
-        from paddle.autograd import PyLayer
-
-        class mp_allreduce_eager(PyLayer):
-            @staticmethod
-            def forward(
-                ctx, tensor, group, use_calc_stream, use_model_parallel
-            ):
-                ctx.ring_id = group.id
-
-                if use_calc_stream:
-                    op_type = _get_reduce_op(op, "_mp_allreduce")
-                    group.process_group.all_reduce_on_calc_stream(
-                        tensor, op_type
-                    )
-                    return tensor
-                else:
-                    return _legacy_C_ops.c_allreduce_sum_(
-                        tensor,
-                        'use_calc_stream',
-                        use_calc_stream,
-                        'ring_id',
-                        ring_id,
-                    )
-
-            @staticmethod
-            def backward(ctx, dy):
-                return _legacy_C_ops.c_identity(
-                    dy,
-                    'use_calc_stream',
-                    True,
-                    'ring_id',
-                    ctx.ring_id,
-                    'use_model_parallel',
-                    True,
-                )
-
         return mp_allreduce_eager.apply(
             tensor, group, use_calc_stream, use_model_parallel
        )
-- 
GitLab
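Note on the fix (annotation, not part of the patch): before this change,
c_identity_eager and mp_allreduce_eager were declared inside _c_identity
and _mp_allreduce, so every forward call built a brand-new PyLayer
subclass that closed over the local `group`. Long-lived framework state
keeps a reference to the PyLayer class behind each recorded op, so none
of those per-call classes (or the objects they captured) could be
reclaimed, and memory grew with the number of calls. Hoisting the classes
to module scope and passing `group` through forward() leaves exactly one
class object for the life of the process. The sketch below illustrates
the pattern in plain Python; `autograd_graph` is a hypothetical stand-in
for whatever long-lived state retains the class, not a Paddle API.

    # Minimal sketch of the leak pattern, assuming some long-lived
    # structure (here `autograd_graph`, hypothetical) holds a reference
    # to the layer class used by each call.
    autograd_graph = []

    def leaky_forward(x):
        # Old pattern: a fresh class object per call; each one (and any
        # state it captures) stays reachable, so memory grows per call.
        class PerCallLayer:
            captured = [0] * 100_000  # stand-in for captured state, e.g. `group`

        autograd_graph.append(PerCallLayer)
        return x

    class ModuleLevelLayer:
        # New pattern: one class for the process lifetime; per-call state
        # is passed in explicitly (cf. `group` in forward()).
        pass

    def fixed_forward(x):
        autograd_graph.append(ModuleLevelLayer)  # same object every call
        return x

    for _ in range(1000):
        leaky_forward(0)
        fixed_forward(0)
    # 1001 distinct classes stay alive: 1000 leaky ones plus 1 shared.
    print(len({id(cls) for cls in autograd_graph}))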