Unverified · Commit ede8fd55 authored by wanghuancoder, committed by GitHub

fix pylayer py39 mem leak (#56623)

* fix pylayer py39 mem leak
Parent 5bfcb501
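Why this fixes a leak, in brief: before this change, each forward pass defined a fresh PyLayer subclass inside a closure, and on Python 3.9 those per-call class objects (together with the tensors their closures captured) could stay alive far longer than the call itself. The commit moves the PyLayer subclasses to module level and passes everything they need through apply()/ctx instead. The sketch below is illustrative only, not code from this diff; the names (closure_style_linear, ModuleLevelLinear) are hypothetical.

import paddle


def closure_style_linear(x, weight):
    # Pattern removed by this commit: a new PyLayer subclass is created on
    # every call, and its methods capture `weight` through the closure, so
    # the per-call class and whatever it captures can outlive the call.
    class _InnerLinear(paddle.autograd.PyLayer):
        @staticmethod
        def forward(ctx, inp):
            ctx.save_for_backward(inp)
            return paddle.matmul(inp, weight)

        @staticmethod
        def backward(ctx, dy):
            (inp,) = ctx.saved_tensor()
            return paddle.matmul(dy, weight, transpose_y=True)

    return _InnerLinear.apply(x)


class ModuleLevelLinear(paddle.autograd.PyLayer):
    # Pattern adopted by this commit: the class is defined once at module
    # scope and receives everything it needs through apply()/ctx, so nothing
    # is captured by a closure.
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return paddle.matmul(inp, weight)

    @staticmethod
    def backward(ctx, dy):
        inp, weight = ctx.saved_tensor()
        d_inp = paddle.matmul(dy, weight, transpose_y=True)
        d_weight = paddle.matmul(inp, dy, transpose_x=True)
        return d_inp, d_weight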
@@ -179,6 +179,136 @@ class VocabParallelEmbedding(paddle.nn.Layer):
        return output
class InnerOverlapLinear(paddle.autograd.PyLayer):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        fuse_matmul_bias,
        mp_async_allreduce,
        mp_skip_c_identity,
        mp_fused_linear_param_grad_add,
        model_parallel_group,
    ):
        ctx.save_for_backward(x, weight, bias)
        ctx.model_parallel_group = model_parallel_group
        ctx.mp_fused_linear_param_grad_add = mp_fused_linear_param_grad_add
        if mp_skip_c_identity is False:
            x = paddle._legacy_C_ops.c_identity(
                x,
                'use_calc_stream',
                True,
                'ring_id',
                model_parallel_group.id,
                'use_model_parallel',
                True,
            )
        if not fuse_matmul_bias:
            return paddle._C_ops.linear(x, weight, bias)
        else:
            return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)

    @staticmethod
    def backward(ctx, dy):
        x, weight, bias = ctx.saved_tensor()
        dx = paddle.matmul(dy, weight, transpose_y=True)
        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
        task = ctx.model_parallel_group.process_group.all_reduce(
            dx, op_type, sync_op=False
        )
        # TODO(GhostScreaming): remove it in future.
        tmp = paddle.ones([512])

        if ctx.mp_fused_linear_param_grad_add:
            if not is_fused_linear_param_grad_add_supported():
                raise NotImplementedError(
                    "You set mp_fused_linear_param_grad_add=True, "
                    "however, the paddle you are using not support this operation. "
                    "Please unset fused_linear_param_grad_add or use paddle compiled "
                    "with cuda 11.6 or higher."
                )

            if bias is None:
                if hasattr(weight, "main_grad"):
                    (
                        weight.main_grad,
                        _,
                    ) = paddle._C_ops.fused_linear_param_grad_add(
                        x, dy, weight.main_grad, None, True, False
                    )
                    task.wait()
                    return dx, None
                else:
                    if weight.grad is not None:
                        (
                            weight.grad,
                            _,
                        ) = paddle._C_ops.fused_linear_param_grad_add(
                            x, dy, weight.grad, None, False, False
                        )
                        task.wait()
                        return dx, None
                    else:
                        (
                            dw,
                            _,
                        ) = paddle._C_ops.fused_linear_param_grad_add(
                            x, dy, None, None, False, False
                        )
                        task.wait()
                        return dx, dw

            if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
                (
                    weight.main_grad,
                    bias.main_grad,
                ) = paddle._C_ops.fused_linear_param_grad_add(
                    x,
                    dy,
                    weight.main_grad,
                    bias.main_grad,
                    True,
                    True,
                )
                task.wait()
                return dx, None, None
            else:
                if weight.grad is not None:
                    assert bias.grad is not None
                    (
                        weight.grad,
                        bias.grad,
                    ) = paddle._C_ops.fused_linear_param_grad_add(
                        x, dy, weight.grad, bias.grad, False, True
                    )
                    task.wait()
                    return dx, None, None
                else:
                    (
                        dw,
                        dbias,
                    ) = paddle._C_ops.fused_linear_param_grad_add(
                        x, dy, None, None, False, True
                    )
                    task.wait()
                    return dx, dw, dbias
        else:
            dw = paddle.matmul(
                x.reshape([-1, x.shape[-1]]),
                dy.reshape([-1, dy.shape[-1]]),
                transpose_x=True,
            )
            if bias is None:
                task.wait()
                return dx, dw
            else:
                dbias = paddle.sum(dy, axis=0)
                task.wait()
                return dx, dw, dbias
class ColumnParallelLinear(paddle.nn.Layer):
    """Linear layer with mp parallelized(column).
    this class is used for splitting Linear Layer in mp group, column split the weight of the Linear layer.
@@ -336,133 +466,16 @@ class ColumnParallelLinear(paddle.nn.Layer):
        # use inner api to process identity
        def _overlap_linear():

Removed:

            fuse_matmul_bias = self.fuse_matmul_bias
            mp_async_allreduce = self.mp_async_allreduce
            mp_skip_c_identity = self.mp_skip_c_identity
            mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add

            class InnerOverlapLinear(paddle.autograd.PyLayer):
                @staticmethod
                def forward(ctx, x, weight, bias):
                    ctx.save_for_backward(x, weight, bias)
                    if mp_skip_c_identity is False:
                        x = paddle._legacy_C_ops.c_identity(
                            x,
                            'use_calc_stream',
                            True,
                            'ring_id',
                            self.model_parallel_group.id,
                            'use_model_parallel',
                            True,
                        )
                    if not fuse_matmul_bias:
                        return paddle._C_ops.linear(x, weight, bias)
                    else:
                        return paddle._legacy_C_ops.fused_gemm_epilogue(
                            x, weight, bias
                        )

                @staticmethod
                def backward(ctx, dy):
                    x, weight, bias = ctx.saved_tensor()
                    dx = paddle.matmul(dy, weight, transpose_y=True)
                    op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
                    task = self.model_parallel_group.process_group.all_reduce(
                        dx, op_type, sync_op=False
                    )
                    # TODO(GhostScreaming): remove it in future.
                    tmp = paddle.ones([512])

                    if mp_fused_linear_param_grad_add:
                        if not is_fused_linear_param_grad_add_supported():
                            raise NotImplementedError(
                                "You set mp_fused_linear_param_grad_add=True, "
                                "however, the paddle you are using not support this operation. "
                                "Please unset fused_linear_param_grad_add or use paddle compiled "
                                "with cuda 11.6 or higher."
                            )

                        if bias is None:
                            if hasattr(weight, "main_grad"):
                                (
                                    weight.main_grad,
                                    _,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.main_grad, None, True, False
                                )
                                task.wait()
                                return dx, None
                            else:
                                if weight.grad is not None:
                                    (
                                        weight.grad,
                                        _,
                                    ) = paddle._C_ops.fused_linear_param_grad_add(
                                        x, dy, weight.grad, None, False, False
                                    )
                                    task.wait()
                                    return dx, None
                                else:
                                    (
                                        dw,
                                        _,
                                    ) = paddle._C_ops.fused_linear_param_grad_add(
                                        x, dy, None, None, False, False
                                    )
                                    task.wait()
                                    return dx, dw

                        if hasattr(weight, "main_grad") and hasattr(
                            bias, "main_grad"
                        ):
                            (
                                weight.main_grad,
                                bias.main_grad,
                            ) = paddle._C_ops.fused_linear_param_grad_add(
                                input,
                                dy,
                                weight.main_grad,
                                bias.main_grad,
                                True,
                                True,
                            )
                            task.wait()
                            return dx, None, None
                        else:
                            if weight.grad is not None:
                                assert bias.grad is not None
                                (
                                    weight.grad,
                                    bias.grad,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.grad, bias.grad, False, True
                                )
                                task.wait()
                                return dx, None, None
                            else:
                                (
                                    dw,
                                    dbias,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, None, None, False, True
                                )
                                task.wait()
                                return dx, dw, dbias
                    else:
                        dw = paddle.matmul(
                            x.reshape([-1, x.shape[-1]]),
                            dy.reshape([-1, dy.shape[-1]]),
                            transpose_x=True,
                        )
                        if bias is None:
                            task.wait()
                            return dx, dw
                        else:
                            dbias = paddle.sum(dy, axis=0)
                            task.wait()
                            return dx, dw, dbias

            return InnerOverlapLinear.apply(x, self.weight, self.bias)

Added:

            return InnerOverlapLinear.apply(
                x,
                self.weight,
                self.bias,
                self.fuse_matmul_bias,
                self.mp_async_allreduce,
                self.mp_skip_c_identity,
                self.mp_fused_linear_param_grad_add,
                self.model_parallel_group,
            )

        if self.mp_async_allreduce:
            output_parallel = _overlap_linear()
...
@@ -14,6 +14,7 @@
import paddle
from paddle import _legacy_C_ops
from paddle.autograd import PyLayer
from paddle.distributed import collective
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
from paddle.framework import LayerHelper, _create_tensor, in_dynamic_mode
@@ -23,6 +24,30 @@ from paddle.nn.utils import dygraph_utils
from ....communication.reduce import ReduceOp, _get_reduce_op
class c_identity_eager(PyLayer):
    @staticmethod
    def forward(ctx, tensor, group, skip_c_identity_dynamic):
        ctx.group = group
        if skip_c_identity_dynamic:
            return tensor
        else:
            return _legacy_C_ops.c_identity(
                tensor,
                'use_calc_stream',
                True,
                'ring_id',
                group.id,
                'use_model_parallel',
                True,
            )

    @staticmethod
    def backward(ctx, dy):
        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
        ctx.group.process_group.all_reduce_on_calc_stream(dy, op_type)
        return dy
def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
    """
    Return a copy of the tensor, mainly used with model parallel.
@@ -40,31 +65,7 @@ def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
    ring_id = 0 if group is None else group.id

    if in_dynamic_mode():

Removed:

        from paddle.autograd import PyLayer

        class c_identity_eager(PyLayer):
            @staticmethod
            def forward(ctx, tensor):
                if skip_c_identity_dynamic:
                    return tensor
                else:
                    return _legacy_C_ops.c_identity(
                        tensor,
                        'use_calc_stream',
                        True,
                        'ring_id',
                        group.id,
                        'use_model_parallel',
                        True,
                    )

            @staticmethod
            def backward(ctx, dy):
                op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
                group.process_group.all_reduce_on_calc_stream(dy, op_type)
                return dy

        return c_identity_eager.apply(tensor)

Added:

        return c_identity_eager.apply(tensor, group, skip_c_identity_dynamic)

    else:
        op_type = 'c_identity'
        helper = LayerHelper(op_type, **locals())
@@ -218,6 +219,49 @@ def _c_split(tensor, group=None):
    return out
class mp_allreduce_eager(PyLayer):
    @staticmethod
    def forward(
        ctx,
        tensor,
        group,
        use_calc_stream,
        use_model_parallel,
        op,
        skip_c_identity_dynamic,
    ):
        ctx.ring_id = group.id
        ctx.skip_c_identity_dynamic = skip_c_identity_dynamic

        if use_calc_stream:
            op_type = _get_reduce_op(op, "_mp_allreduce")
            group.process_group.all_reduce_on_calc_stream(tensor, op_type)
            return tensor
        else:
            return _legacy_C_ops.c_allreduce_sum_(
                tensor,
                'use_calc_stream',
                use_calc_stream,
                'ring_id',
                group.id,
            )

    @staticmethod
    def backward(ctx, dy):
        if ctx.skip_c_identity_dynamic:
            return dy
        else:
            return _legacy_C_ops.c_identity(
                dy,
                'use_calc_stream',
                True,
                'ring_id',
                ctx.ring_id,
                'use_model_parallel',
                True,
            )
def _mp_allreduce(
    tensor,
    op=ReduceOp.SUM,
@@ -233,48 +277,13 @@ def _mp_allreduce(
    if in_dynamic_mode():
        group = collective._get_default_group() if group is None else group
        assert op == ReduceOp.SUM, f"Unknown parameter: {op}."

Removed:

        from paddle.autograd import PyLayer

        class mp_allreduce_eager(PyLayer):
            @staticmethod
            def forward(
                ctx, tensor, group, use_calc_stream, use_model_parallel
            ):
                ctx.ring_id = group.id

                if use_calc_stream:
                    op_type = _get_reduce_op(op, "_mp_allreduce")
                    group.process_group.all_reduce_on_calc_stream(
                        tensor, op_type
                    )
                    return tensor
                else:
                    return _legacy_C_ops.c_allreduce_sum_(
                        tensor,
                        'use_calc_stream',
                        use_calc_stream,
                        'ring_id',
                        ring_id,
                    )

            @staticmethod
            def backward(ctx, dy):
                if skip_c_identity_dynamic:
                    return dy
                else:
                    return _legacy_C_ops.c_identity(
                        dy,
                        'use_calc_stream',
                        True,
                        'ring_id',
                        ctx.ring_id,
                        'use_model_parallel',
                        True,
                    )

        return mp_allreduce_eager.apply(
            tensor, group, use_calc_stream, use_model_parallel
        )

Added:

        return mp_allreduce_eager.apply(
            tensor,
            group,
            use_calc_stream,
            use_model_parallel,
            op,
            skip_c_identity_dynamic,
        )

    else:
        ring_id = 0 if group is None else group.id
...
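For orientation, the two PyLayers added above form the usual pair of model-parallel collectives: c_identity_eager is an identity in forward with an all-reduce of the gradient in backward, while mp_allreduce_eager all-reduces in forward and is an identity in backward. A stripped-down sketch of that pairing using the public paddle.distributed API is shown below; it is illustrative only, uses the default communication group rather than a model-parallel group, and assumes dist.init_parallel_env() has already been called. The class names are hypothetical.

import paddle
import paddle.distributed as dist


class CopyToModelParallel(paddle.autograd.PyLayer):
    # Plays the role of _c_identity: identity in forward,
    # all-reduce of the incoming gradient in backward.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, dy):
        dist.all_reduce(dy)
        return dy


class ReduceFromModelParallel(paddle.autograd.PyLayer):
    # Plays the role of _mp_allreduce: all-reduce in forward,
    # identity in backward.
    @staticmethod
    def forward(ctx, x):
        dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, dy):
        return dy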