Unverified · Commit 6b1dfb5f authored by Ghost Screaming, committed by GitHub

Add mp_all_reduce asynchronous overlap. (#55662)

* [WIP] Add mp_all_reduce asynchronous overlap.

* Fix some problems.

* Fix dw compute bug, and use a temporary solution to achieve overlap.

* Use fused_linear_param_grad_add to compute dw.

* Reformat ColumnParallel _overlap_linear. Use environment flags to
  control the following behaviors (see the usage sketch after this list):
  1. export Flags_mp_aysnc_allreduce=True to turn on mp async all_reduce.
  2. export Flags_skip_mp_c_identity=True to skip the two c_identity operators
     in dygraph mode.
  3. export Flags_fused_linear_param_grad_add=True to enable fused_linear_param_grad_add
     in the ColumnParallel backward with mp async all_reduce.

* Polish code.

* Remove useless communication API.

* Fix some problems in mp_async_all_reduce and skip_c_identity.

* Add test cases.

* Remove environment variable Flags_fused_linear_param_grad_add in test case.

* Reset error threshold.

* Reset threshold in test case.

* Add useful log. Remove useless test cases.
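For reference, a minimal usage sketch of the three flags (the Python equivalent of the shell exports listed above; apart from the flag names, nothing here is prescribed by this commit, and the fleet import is only illustrative). The flags are ordinary environment variables read with os.getenv each time they are checked:

import os

# Assumed launch-time setup (illustration only):
os.environ["Flags_mp_aysnc_allreduce"] = "True"           # overlap all_reduce(dx) with matmul(dw)
os.environ["Flags_skip_mp_c_identity"] = "True"           # skip the two c_identity ops in dygraph mode
os.environ["Flags_fused_linear_param_grad_add"] = "True"  # compute dw/dbias with fused_linear_param_grad_add

import paddle.distributed.fleet as fleet  # ColumnParallelLinear built via fleet reads the flags at call time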
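The overlap itself reduces to the following schematic (a simplified sketch of InnerOverlapLinear.backward from the diff below, showing only the non-fused, bias-free path; group and op_type stand in for the model-parallel group and the SUM reduce op resolved in the real code):

def backward(ctx, dy):
    x, weight, bias = ctx.saved_tensor()
    # dx needs a model-parallel all_reduce, but its result is not used again
    # inside backward, so launch the communication asynchronously ...
    dx = paddle.matmul(dy, weight, transpose_y=True)
    task = group.process_group.all_reduce(dx, op_type, sync_op=False)
    # ... and overlap it with the local weight-gradient matmul.
    dw = paddle.matmul(
        x.reshape([-1, x.shape[-1]]),
        dy.reshape([-1, dy.shape[-1]]),
        transpose_x=True,
    )
    task.wait()  # dx is only valid after the all_reduce has finished
    return dx, dw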
Parent commit: a8981be0
@@ -17,8 +17,10 @@ from paddle.autograd import PyLayer
from paddle.fluid import core
from paddle.nn import functional as F

from ....communication.reduce import ReduceOp, _get_reduce_op
from ...base import topology as tp
from . import mp_ops
from .mp_ops import _get_mp_env_flag
from .random import get_rng_state_tracker

__all__ = []
@@ -32,6 +34,13 @@ def is_fused_matmul_bias_supported():
        return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')


def is_fused_linear_param_grad_add_supported():
    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
        return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')
    else:
        return False


class VocabParallelEmbedding(paddle.nn.Layer):
    """Embedding mp parallelized in the vocabulary dimension.
    this class is used for splitting embedding in mp group.
@@ -295,7 +304,8 @@ class ColumnParallelLinear(paddle.nn.Layer):
        self.linear = F.linear

        self.fuse_matmul_bias = fuse_matmul_bias
        if self.fuse_matmul_bias:
            if not is_fused_matmul_bias_supported():
                raise NotImplementedError(
                    "You set fuse_matmul_bias=True in ColumnParallelLinear, "
@@ -309,6 +319,139 @@ class ColumnParallelLinear(paddle.nn.Layer):
    def forward(self, x):
        # use inner api to process identity
        def _overlap_linear():
            fuse_matmul_bias = self.fuse_matmul_bias

            class InnerOverlapLinear(paddle.autograd.PyLayer):
                @staticmethod
                def forward(ctx, x, weight, bias):
                    ctx.save_for_backward(x, weight, bias)
                    if (
                        _get_mp_env_flag("Flags_mp_aysnc_allreduce")
                        and _get_mp_env_flag("Flags_skip_mp_c_identity")
                    ) is False:
                        x = paddle._legacy_C_ops.c_identity(
                            x,
                            'use_calc_stream',
                            True,
                            'ring_id',
                            self.model_parallel_group.id,
                            'use_model_parallel',
                            True,
                        )
                    if not fuse_matmul_bias:
                        return paddle._C_ops.linear(x, weight, bias)
                    else:
                        return paddle._legacy_C_ops.fused_gemm_epilogue(
                            x, weight, bias
                        )

                @staticmethod
                def backward(ctx, dy):
                    x, weight, bias = ctx.saved_tensor()
                    dx = paddle.matmul(dy, weight, transpose_y=True)
                    op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
                    task = self.model_parallel_group.process_group.all_reduce(
                        dx, op_type, sync_op=False
                    )
                    # TODO(GhostScreaming): remove it in future.
                    tmp = paddle.ones([512])
if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set environment variable Flags_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
                        if bias is None:
                            if hasattr(weight, "main_grad"):
                                (
                                    weight.main_grad,
                                    _,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.main_grad, None, True, False
                                )
                                task.wait()
                                return dx, None
                            else:
                                if weight.grad is not None:
                                    (
                                        weight.grad,
                                        _,
                                    ) = paddle._C_ops.fused_linear_param_grad_add(
                                        x, dy, weight.grad, None, False, False
                                    )
                                    task.wait()
                                    return dx, None
                                else:
                                    (
                                        dw,
                                        _,
                                    ) = paddle._C_ops.fused_linear_param_grad_add(
                                        x, dy, None, None, False, False
                                    )
                                    task.wait()
                                    return dx, dw

                        if hasattr(weight, "main_grad") and hasattr(
                            bias, "main_grad"
                        ):
                            (
                                weight.main_grad,
                                bias.main_grad,
                            ) = paddle._C_ops.fused_linear_param_grad_add(
                                x,
                                dy,
                                weight.main_grad,
                                bias.main_grad,
                                True,
                                True,
                            )
                            task.wait()
                            return dx, None, None
                        else:
                            if weight.grad is not None:
                                assert bias.grad is not None
                                (
                                    weight.grad,
                                    bias.grad,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, weight.grad, bias.grad, False, True
                                )
                                task.wait()
                                return dx, None, None
                            else:
                                (
                                    dw,
                                    dbias,
                                ) = paddle._C_ops.fused_linear_param_grad_add(
                                    x, dy, None, None, False, True
                                )
                                task.wait()
                                return dx, dw, dbias
                    else:
                        dw = paddle.matmul(
                            x.reshape([-1, x.shape[-1]]),
                            dy.reshape([-1, dy.shape[-1]]),
                            transpose_x=True,
                        )
                        if bias is None:
                            task.wait()
                            return dx, dw
                        else:
                            dbias = paddle.sum(dy, axis=0)
                            task.wait()
                            return dx, dw, dbias

            return InnerOverlapLinear.apply(x, self.weight, self.bias)

        if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
            output_parallel = _overlap_linear()
        else:
            if self.is_mp:
                input_parallel = mp_ops._c_identity(
                    x, group=self.model_parallel_group
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle import _legacy_C_ops
from paddle.distributed import collective
@@ -22,6 +24,38 @@ from paddle.nn.utils import dygraph_utils
from ....communication.reduce import ReduceOp, _get_reduce_op
_first_get_mp_env_flag = True


def _get_mp_env_flag(flag):
    global _first_get_mp_env_flag
    if _first_get_mp_env_flag:
        print(
            "Flags_mp_aysnc_allreduce is {}, which is used to support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear.".format(
                str(os.getenv("Flags_mp_aysnc_allreduce")).lower()
            )
        )
        print(
            "Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
                str(os.getenv("Flags_fused_linear_param_grad_add")).lower()
            )
        )
        print(
            "Flags_skip_mp_c_identity is {}, which is used to support skipping c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
                str(os.getenv("Flags_skip_mp_c_identity")).lower()
            )
        )
        _first_get_mp_env_flag = False
    # Model parallel environment flags:
    # Flags_mp_aysnc_allreduce: overlap all_reduce(dx) with matmul(dw) in ColumnParallelLinear.
    # Flags_fused_linear_param_grad_add: use fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
    # Flags_skip_mp_c_identity: skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
    assert flag in [
        "Flags_mp_aysnc_allreduce",
        "Flags_fused_linear_param_grad_add",
        "Flags_skip_mp_c_identity",
    ], (
        "Only supports setting Flags_mp_aysnc_allreduce (overlap all_reduce(dx) with matmul(dw) in "
        "ColumnParallelLinear), Flags_fused_linear_param_grad_add (use fused_linear_param_grad_add in "
        "ColumnParallelLinear) and Flags_skip_mp_c_identity (skip c_identity in ColumnParallelLinear "
        "with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)."
    )
    return str(os.getenv(flag)).lower() in ["true", "1"]
def _c_identity(tensor, group=None):
    """
@@ -45,6 +79,11 @@ def _c_identity(tensor, group=None):
        class c_identity_eager(PyLayer):
            @staticmethod
            def forward(ctx, tensor):
                if _get_mp_env_flag(
                    "Flags_mp_aysnc_allreduce"
                ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
                    return tensor
                else:
                    return _legacy_C_ops.c_identity(
                        tensor,
                        'use_calc_stream',
@@ -256,6 +295,11 @@ def _mp_allreduce(
            @staticmethod
            def backward(ctx, dy):
                if _get_mp_env_flag(
                    "Flags_mp_aysnc_allreduce"
                ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
                    return dy
                else:
                    return _legacy_C_ops.c_identity(
                        dy,
                        'use_calc_stream',
...