Unverified commit ede8fd55, authored by wanghuancoder, committed by GitHub

fix pylayer py39 mem leak (#56623)

* fix pylayer py39 mem leak
Parent 5bfcb501
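The whole change follows one pattern: PyLayer subclasses that used to be defined inside a function, and therefore read outer variables through a closure, are hoisted to module level, and everything their backward() needs is now passed to forward() explicitly and stored on ctx. Below is a minimal sketch of that before/after refactoring; the class, function, and variable names are illustrative and not from this PR.

```python
import paddle
from paddle.autograd import PyLayer

# Illustrative names only; not from the PR.

# Before: a new PyLayer subclass is created on every call, and backward()
# reads `scale` from the enclosing closure.
def scaled_identity_leaky(x, scale):
    class _ScaledIdentity(PyLayer):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * scale

        @staticmethod
        def backward(ctx, dy):
            (x,) = ctx.saved_tensor()
            return dy * scale  # closure reference back to the caller's frame

    return _ScaledIdentity.apply(x)


# After: one module-level class; state travels through forward() and ctx,
# so nothing in the autograd graph refers to the caller's frame.
class _ScaledIdentity(PyLayer):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        ctx.save_for_backward(x)
        return x * scale

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensor()
        return dy * ctx.scale


def scaled_identity(x, scale):
    return _ScaledIdentity.apply(x, scale)
```

In the PR the same move is applied to InnerOverlapLinear, c_identity_eager, and mp_allreduce_eager: the configuration flags and the process group that backward() needs become forward() arguments kept on ctx (ctx.model_parallel_group, ctx.group, ctx.skip_c_identity_dynamic) instead of closure captures.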
@@ -179,6 +179,136 @@ class VocabParallelEmbedding(paddle.nn.Layer):
return output
class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(
ctx,
x,
weight,
bias,
fuse_matmul_bias,
mp_async_allreduce,
mp_skip_c_identity,
mp_fused_linear_param_grad_add,
model_parallel_group,
):
ctx.save_for_backward(x, weight, bias)
ctx.model_parallel_group = model_parallel_group
ctx.mp_fused_linear_param_grad_add = mp_fused_linear_param_grad_add
if mp_skip_c_identity is False:
x = paddle._legacy_C_ops.c_identity(
x,
'use_calc_stream',
True,
'ring_id',
model_parallel_group.id,
'use_model_parallel',
True,
)
if not fuse_matmul_bias:
return paddle._C_ops.linear(x, weight, bias)
else:
return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
@staticmethod
def backward(ctx, dy):
x, weight, bias = ctx.saved_tensor()
dx = paddle.matmul(dy, weight, transpose_y=True)
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
task = ctx.model_parallel_group.process_group.all_reduce(
dx, op_type, sync_op=False
)
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])
if ctx.mp_fused_linear_param_grad_add:
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set mp_fused_linear_param_grad_add=True, "
"however, the paddle you are using does not support this operation. "
"Please unset fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
if bias is None:
if hasattr(weight, "main_grad"):
(
weight.main_grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.main_grad, None, True, False
)
task.wait()
return dx, None
else:
if weight.grad is not None:
(
weight.grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, None, False, False
)
task.wait()
return dx, None
else:
(
dw,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, False
)
task.wait()
return dx, dw
if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
(
weight.main_grad,
bias.main_grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x,
dy,
weight.main_grad,
bias.main_grad,
True,
True,
)
task.wait()
return dx, None, None
else:
if weight.grad is not None:
assert bias.grad is not None
(
weight.grad,
bias.grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, bias.grad, False, True
)
task.wait()
return dx, None, None
else:
(
dw,
dbias,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, True
)
task.wait()
return dx, dw, dbias
else:
dw = paddle.matmul(
x.reshape([-1, x.shape[-1]]),
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
)
if bias is None:
task.wait()
return dx, dw
else:
dbias = paddle.sum(dy, axis=0)
task.wait()
return dx, dw, dbias
class ColumnParallelLinear(paddle.nn.Layer):
"""Linear layer with mp parallelized(column).
this class is used for splitting Linear Layer in mp group, column split the weight of the Linear layer.
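The backward() of InnerOverlapLinear above also shows what mp_async_allreduce buys: the all-reduce of dx is launched with sync_op=False, the weight (and bias) gradients are computed while the collective is in flight, and task.wait() runs just before dx is returned. The following is a standalone sketch of that overlap pattern using the public paddle.distributed API; it assumes the script is launched with paddle.distributed.launch on one or more devices, and the function name is illustrative.

```python
import paddle
import paddle.distributed as dist


def overlapped_linear_backward(x, weight, dy, group=None):
    # Illustrative sketch, not the PR's code path.
    # Gradient w.r.t. the input; this is the tensor the all-reduce acts on.
    dx = paddle.matmul(dy, weight, transpose_y=True)
    # Launch the all-reduce asynchronously; sync_op=False returns a task handle.
    task = dist.all_reduce(dx, op=dist.ReduceOp.SUM, group=group, sync_op=False)
    # The weight gradient is computed while the collective runs in the background.
    dw = paddle.matmul(
        x.reshape([-1, x.shape[-1]]),
        dy.reshape([-1, dy.shape[-1]]),
        transpose_x=True,
    )
    task.wait()  # dx must be fully reduced before it is handed back
    return dx, dw


if __name__ == "__main__":
    dist.init_parallel_env()
    x = paddle.randn([8, 16])
    weight = paddle.randn([16, 32])
    dy = paddle.randn([8, 32])
    dx, dw = overlapped_linear_backward(x, weight, dy)
    print(dx.shape, dw.shape)
```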
@@ -336,133 +466,16 @@ class ColumnParallelLinear(paddle.nn.Layer):
# use inner api to process identity
def _overlap_linear():
fuse_matmul_bias = self.fuse_matmul_bias
mp_async_allreduce = self.mp_async_allreduce
mp_skip_c_identity = self.mp_skip_c_identity
mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add
class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight, bias):
ctx.save_for_backward(x, weight, bias)
if mp_skip_c_identity is False:
x = paddle._legacy_C_ops.c_identity(
return InnerOverlapLinear.apply(
x,
'use_calc_stream',
True,
'ring_id',
self.model_parallel_group.id,
'use_model_parallel',
True,
)
if not fuse_matmul_bias:
return paddle._C_ops.linear(x, weight, bias)
else:
return paddle._legacy_C_ops.fused_gemm_epilogue(
x, weight, bias
)
@staticmethod
def backward(ctx, dy):
x, weight, bias = ctx.saved_tensor()
dx = paddle.matmul(dy, weight, transpose_y=True)
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
task = self.model_parallel_group.process_group.all_reduce(
dx, op_type, sync_op=False
)
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])
if mp_fused_linear_param_grad_add:
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set mp_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
if bias is None:
if hasattr(weight, "main_grad"):
(
weight.main_grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.main_grad, None, True, False
)
task.wait()
return dx, None
else:
if weight.grad is not None:
(
weight.grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, None, False, False
)
task.wait()
return dx, None
else:
(
dw,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, False
)
task.wait()
return dx, dw
if hasattr(weight, "main_grad") and hasattr(
bias, "main_grad"
):
(
weight.main_grad,
bias.main_grad,
) = paddle._C_ops.fused_linear_param_grad_add(
input,
dy,
weight.main_grad,
bias.main_grad,
True,
True,
)
task.wait()
return dx, None, None
else:
if weight.grad is not None:
assert bias.grad is not None
(
weight.grad,
bias.grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, bias.grad, False, True
)
task.wait()
return dx, None, None
else:
(
dw,
dbias,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, True
)
task.wait()
return dx, dw, dbias
else:
dw = paddle.matmul(
x.reshape([-1, x.shape[-1]]),
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
self.weight,
self.bias,
self.fuse_matmul_bias,
self.mp_async_allreduce,
self.mp_skip_c_identity,
self.mp_fused_linear_param_grad_add,
self.model_parallel_group,
)
if bias is None:
task.wait()
return dx, dw
else:
dbias = paddle.sum(dy, axis=0)
task.wait()
return dx, dw, dbias
return InnerOverlapLinear.apply(x, self.weight, self.bias)
if self.mp_async_allreduce:
output_parallel = _overlap_linear()
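The rewritten call site above is the other half of the fix: InnerOverlapLinear.apply() now receives the flags and the model-parallel group explicitly, and, as the backward() in the first hunk shows, gradients are returned only for the tensor inputs (dx, dw, and dbias when a bias exists), never for the plain-Python arguments. Here is a small sketch of a PyLayer that mixes tensor and non-tensor inputs in the same way; the class and variable names are illustrative, not from the PR.

```python
import paddle
from paddle.autograd import PyLayer


class _BiasedScale(PyLayer):
    # Illustrative example, not from the PR.
    @staticmethod
    def forward(ctx, x, bias, scale, use_bias):
        # Non-tensor arguments are stashed on ctx for use in backward().
        ctx.scale = scale
        ctx.use_bias = use_bias
        ctx.save_for_backward(x, bias)
        out = x * scale
        return out + bias if use_bias else out

    @staticmethod
    def backward(ctx, dy):
        x, bias = ctx.saved_tensor()
        dx = dy * ctx.scale
        # Gradients only for the tensor inputs (x, bias); the plain-Python
        # arguments scale and use_bias get no gradient slot.
        if ctx.use_bias:
            return dx, paddle.sum(dy, axis=0)
        return dx, None


x = paddle.randn([4, 8])
x.stop_gradient = False
bias = paddle.zeros([8])
bias.stop_gradient = False
y = _BiasedScale.apply(x, bias, 2.0, True)
y.sum().backward()
print(x.grad.shape, bias.grad.shape)
```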
@@ -14,6 +14,7 @@
import paddle
from paddle import _legacy_C_ops
from paddle.autograd import PyLayer
from paddle.distributed import collective
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
from paddle.framework import LayerHelper, _create_tensor, in_dynamic_mode
@@ -23,28 +24,10 @@ from paddle.nn.utils import dygraph_utils
from ....communication.reduce import ReduceOp, _get_reduce_op
def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
"""
Return a copy of the tensor, mainly used with model parallel.
Args:
tensor (Tensor): The input Tensor. Its data type
should be float16, float32, float64, int32 or int64.
group (int): The id of the process group to work on.
Returns:
Tensor.
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if in_dynamic_mode():
from paddle.autograd import PyLayer
class c_identity_eager(PyLayer):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
def forward(ctx, tensor, group, skip_c_identity_dynamic):
ctx.group = group
if skip_c_identity_dynamic:
return tensor
else:
@@ -61,10 +44,28 @@ def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
@staticmethod
def backward(ctx, dy):
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
group.process_group.all_reduce_on_calc_stream(dy, op_type)
ctx.group.process_group.all_reduce_on_calc_stream(dy, op_type)
return dy
return c_identity_eager.apply(tensor)
def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
"""
Return a copy of the tensor, mainly used with model parallel.
Args:
tensor (Tensor): The input Tensor. Its data type
should be float16, float32, float64, int32 or int64.
group (int): The id of the process group to work on.
Returns:
Tensor.
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if in_dynamic_mode():
return c_identity_eager.apply(tensor, group, skip_c_identity_dynamic)
else:
op_type = 'c_identity'
helper = LayerHelper(op_type, **locals())
@@ -218,36 +219,23 @@ def _c_split(tensor, group=None):
return out
def _mp_allreduce(
tensor,
op=ReduceOp.SUM,
group=None,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=False,
):
"""[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]"""
if group is not None and not group.is_member():
return
if in_dynamic_mode():
group = collective._get_default_group() if group is None else group
assert op == ReduceOp.SUM, f"Unknown parameter: {op}."
from paddle.autograd import PyLayer
class mp_allreduce_eager(PyLayer):
class mp_allreduce_eager(PyLayer):
@staticmethod
def forward(
ctx, tensor, group, use_calc_stream, use_model_parallel
ctx,
tensor,
group,
use_calc_stream,
use_model_parallel,
op,
skip_c_identity_dynamic,
):
ctx.ring_id = group.id
ctx.skip_c_identity_dynamic = skip_c_identity_dynamic
if use_calc_stream:
op_type = _get_reduce_op(op, "_mp_allreduce")
group.process_group.all_reduce_on_calc_stream(
tensor, op_type
)
group.process_group.all_reduce_on_calc_stream(tensor, op_type)
return tensor
else:
return _legacy_C_ops.c_allreduce_sum_(
@@ -255,12 +243,12 @@ def _mp_allreduce(
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
group.id,
)
@staticmethod
def backward(ctx, dy):
if skip_c_identity_dynamic:
if ctx.skip_c_identity_dynamic:
return dy
else:
return _legacy_C_ops.c_identity(
@@ -273,8 +261,29 @@ def _mp_allreduce(
True,
)
def _mp_allreduce(
tensor,
op=ReduceOp.SUM,
group=None,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=False,
):
"""[Same as allreduce above, but it supports model parallel and the inplace strategy.]"""
if group is not None and not group.is_member():
return
if in_dynamic_mode():
group = collective._get_default_group() if group is None else group
assert op == ReduceOp.SUM, f"Unknown parameter: {op}."
return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
tensor,
group,
use_calc_stream,
use_model_parallel,
op,
skip_c_identity_dynamic,
)
else:
ring_id = 0 if group is None else group.id