未验证 提交 88a975a0 编写于 作者: G Ghost Screaming 提交者: GitHub

Change flags for mp async all reduce (#56456)

上级 7577a67a
......@@ -56,6 +56,12 @@ message MpConfig {
optional bool sync_grad= 2 [ default = false ];
optional bool sync_moment= 3 [ default = false ];
optional string sync_mode= 4 [ default = 'broadcast' ];
// Support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
optional bool mp_async_allreduce= 5 [default = false ];
// Support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when mp_async_allreduce is true.
optional bool mp_skip_c_identity= 6 [default = false ];
// Support fused_linear_param_grad_add in ColumnParallelLinear. Only works when mp_async_allreduce is true.
optional bool mp_fused_linear_param_grad_add= 7 [default = false ];
}
message PpConfig {
......
......@@ -14,13 +14,13 @@
import paddle
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.fluid import core
from paddle.nn import functional as F
from ....communication.reduce import ReduceOp, _get_reduce_op
from ...base import topology as tp
from . import mp_ops
from .mp_ops import _get_mp_env_flag
from .random import get_rng_state_tracker
__all__ = []
......@@ -305,6 +305,21 @@ class ColumnParallelLinear(paddle.nn.Layer):
self.linear = F.linear
self.fuse_matmul_bias = fuse_matmul_bias
mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce
self.mp_skip_c_identity = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_skip_c_identity
)
self.mp_fused_linear_param_grad_add = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_fused_linear_param_grad_add
)
if self.fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
......@@ -322,15 +337,15 @@ class ColumnParallelLinear(paddle.nn.Layer):
def _overlap_linear():
fuse_matmul_bias = self.fuse_matmul_bias
mp_async_allreduce = self.mp_async_allreduce
mp_skip_c_identity = self.mp_skip_c_identity
mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add
class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight, bias):
ctx.save_for_backward(x, weight, bias)
if (
_get_mp_env_flag("Flags_mp_aysnc_allreduce")
and _get_mp_env_flag("Flags_skip_mp_c_identity")
) is False:
if mp_skip_c_identity is False:
x = paddle._legacy_C_ops.c_identity(
x,
'use_calc_stream',
......@@ -358,12 +373,12 @@ class ColumnParallelLinear(paddle.nn.Layer):
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])
if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
if mp_fused_linear_param_grad_add:
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set environment variable Flags_fused_linear_param_grad_add=True, "
"You set mp_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
"Please unset fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
......@@ -449,12 +464,14 @@ class ColumnParallelLinear(paddle.nn.Layer):
return InnerOverlapLinear.apply(x, self.weight, self.bias)
if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
if self.mp_async_allreduce:
output_parallel = _overlap_linear()
else:
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
x,
group=self.model_parallel_group,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
else:
input_parallel = x
......@@ -570,6 +587,20 @@ class RowParallelLinear(paddle.nn.Layer):
)
self.is_mp = self.world_size > 1
mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce
self.mp_skip_c_identity = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_skip_c_identity
)
self.mp_fused_linear_param_grad_add = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_fused_linear_param_grad_add
)
assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(
......@@ -642,6 +673,7 @@ class RowParallelLinear(paddle.nn.Layer):
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
else:
output_parallel = self.linear(
......@@ -652,6 +684,7 @@ class RowParallelLinear(paddle.nn.Layer):
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
output = (
output_ + self.bias if self.bias is not None else output_
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import _legacy_C_ops
from paddle.distributed import collective
......@@ -24,40 +22,8 @@ from paddle.nn.utils import dygraph_utils
from ....communication.reduce import ReduceOp, _get_reduce_op
# True until the flag banner in _get_mp_env_flag has been printed once.
_first_get_mp_env_flag = True


def _get_mp_env_flag(flag):
    """Return whether a model-parallel environment flag is switched on.

    On the first call only, the current raw value of every supported flag
    is printed to stdout so the active configuration shows up in the log.

    Supported flags (the spelling ``aysnc`` is the actual environment
    variable name and must not be corrected):

    * ``Flags_mp_aysnc_allreduce``: support all_reduce(dx) overlap with
      matmul(dw) in ColumnParallelLinear.
    * ``Flags_fused_linear_param_grad_add``: support
      fused_linear_param_grad_add in ColumnParallelLinear. Only works when
      ``Flags_mp_aysnc_allreduce`` is True.
    * ``Flags_skip_mp_c_identity``: support skip c_identity in
      ColumnParallelLinear and RowParallelLinear. Only works when
      ``Flags_mp_aysnc_allreduce`` is True.

    Args:
        flag (str): Name of the environment variable to query; must be one
            of the supported flags listed above.

    Returns:
        bool: True when the variable's value, compared case-insensitively,
        is ``"true"`` or ``"1"``; False otherwise, including when the
        variable is unset.

    Raises:
        AssertionError: If ``flag`` is not one of the supported flags.
    """
    global _first_get_mp_env_flag
    if _first_get_mp_env_flag:
        # Clear the marker so the banner is emitted once per process instead
        # of on every flag lookup (the marker was previously never reset).
        _first_get_mp_env_flag = False
        print(
            "Flags_mp_aysnc_allreduce is {}, which is used to support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear.".format(
                str(os.getenv("Flags_mp_aysnc_allreduce")).lower()
            )
        )
        print(
            "Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
                str(os.getenv("Flags_fused_linear_param_grad_add")).lower()
            )
        )
        print(
            "Flags_skip_mp_c_identity is {}, which is used to support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
                str(os.getenv("Flags_skip_mp_c_identity")).lower()
            )
        )
    # Model parallel environment flag.
    # Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
    # Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
    # Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
    assert flag in [
        "Flags_mp_aysnc_allreduce",
        "Flags_fused_linear_param_grad_add",
        "Flags_skip_mp_c_identity",
    ], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"
    # An unset variable yields str(None).lower() == "none", i.e. False.
    return str(os.getenv(flag)).lower() in ["true", "1"]
def _c_identity(tensor, group=None):
def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
"""
Return a copy of the tensor, mainly used with model parallel.
......@@ -79,9 +45,7 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
if skip_c_identity_dynamic:
return tensor
else:
return _legacy_C_ops.c_identity(
......@@ -260,6 +224,7 @@ def _mp_allreduce(
group=None,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=False,
):
"""[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]"""
if group is not None and not group.is_member():
......@@ -295,9 +260,7 @@ def _mp_allreduce(
@staticmethod
def backward(ctx, dy):
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
if skip_c_identity_dynamic:
return dy
else:
return _legacy_C_ops.c_identity(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册