Unverified commit 88a975a0 authored by Ghost Screaming and committed by GitHub

Change flags for mp async all reduce (#56456)

Parent 7577a67a
...@@ -56,6 +56,12 @@ message MpConfig {
optional bool sync_grad= 2 [ default = false ];
optional bool sync_moment= 3 [ default = false ];
optional string sync_mode= 4 [ default = 'broadcast' ];
// Support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
optional bool mp_async_allreduce= 5 [default = false ];
// Support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when mp_async_allreduce is true.
optional bool mp_skip_c_identity= 6 [default = false ];
// Support fused_linear_param_grad_add in ColumnParallelLinear. Only works when mp_async_allreduce is true.
optional bool mp_fused_linear_param_grad_add= 7 [default = false ];
}
message PpConfig {
...
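The three new MpConfig fields sit alongside the existing sync switches and are intended to be toggled through fleet's hybrid_configs. Below is a minimal sketch of enabling them; the parallel degrees are placeholders and the nested-dict assignment follows common DistributedStrategy usage rather than anything shown in this commit:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Placeholder degrees; the mp_configs keys mirror the proto fields added above.
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,
    "pp_degree": 1,
    "mp_configs": {
        "mp_async_allreduce": True,  # overlap all_reduce(dx) with matmul(dw)
        "mp_skip_c_identity": True,  # only takes effect when mp_async_allreduce is True
        "mp_fused_linear_param_grad_add": True,  # only takes effect when mp_async_allreduce is True
    },
}
fleet.init(is_collective=True, strategy=strategy)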
...@@ -14,13 +14,13 @@
import paddle
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.fluid import core
from paddle.nn import functional as F
from ....communication.reduce import ReduceOp, _get_reduce_op
from ...base import topology as tp
from . import mp_ops
from .mp_ops import _get_mp_env_flag
from .random import get_rng_state_tracker
__all__ = []
...@@ -305,6 +305,21 @@ class ColumnParallelLinear(paddle.nn.Layer):
self.linear = F.linear
self.fuse_matmul_bias = fuse_matmul_bias
mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce
self.mp_skip_c_identity = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_skip_c_identity
)
self.mp_fused_linear_param_grad_add = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_fused_linear_param_grad_add
)
if self.fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
...@@ -322,15 +337,15 @@ class ColumnParallelLinear(paddle.nn.Layer):
def _overlap_linear():
fuse_matmul_bias = self.fuse_matmul_bias
mp_async_allreduce = self.mp_async_allreduce
mp_skip_c_identity = self.mp_skip_c_identity
mp_fused_linear_param_grad_add = self.mp_fused_linear_param_grad_add
class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight, bias):
ctx.save_for_backward(x, weight, bias)
if (
_get_mp_env_flag("Flags_mp_aysnc_allreduce")
and _get_mp_env_flag("Flags_skip_mp_c_identity")
) is False:
if mp_skip_c_identity is False:
x = paddle._legacy_C_ops.c_identity(
x,
'use_calc_stream',
...@@ -358,12 +373,12 @@ class ColumnParallelLinear(paddle.nn.Layer):
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])
if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
if mp_fused_linear_param_grad_add:
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set environment variable Flags_fused_linear_param_grad_add=True, "
"You set mp_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
"Please unset fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
...@@ -449,12 +464,14 @@ class ColumnParallelLinear(paddle.nn.Layer):
return InnerOverlapLinear.apply(x, self.weight, self.bias)
if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
if self.mp_async_allreduce:
output_parallel = _overlap_linear()
else:
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
x,
group=self.model_parallel_group,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
else:
input_parallel = x
...@@ -570,6 +587,20 @@ class RowParallelLinear(paddle.nn.Layer):
)
self.is_mp = self.world_size > 1
mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
self.mp_async_allreduce = self.is_mp and mp_configs.mp_async_allreduce
self.mp_skip_c_identity = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_skip_c_identity
)
self.mp_fused_linear_param_grad_add = (
self.is_mp
and mp_configs.mp_async_allreduce
and mp_configs.mp_fused_linear_param_grad_add
)
assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(
...@@ -642,6 +673,7 @@ class RowParallelLinear(paddle.nn.Layer):
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
else:
output_parallel = self.linear(
...@@ -652,6 +684,7 @@ class RowParallelLinear(paddle.nn.Layer):
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=self.mp_skip_c_identity,
)
output = (
output_ + self.bias if self.bias is not None else output_
...
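With this change, ColumnParallelLinear and RowParallelLinear resolve the switches once at construction time from the user-defined strategy instead of reading environment variables on every forward call. A condensed sketch of that gating, assuming fleet.init has already been run; the standalone helper name is illustrative only:

from paddle.distributed import fleet

def resolve_mp_flags(is_mp):
    # Mirrors the __init__ logic above: the two auxiliary switches are only
    # honored when model parallelism is active and mp_async_allreduce is on.
    mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs["mp_configs"]
    mp_async_allreduce = is_mp and mp_configs.mp_async_allreduce
    mp_skip_c_identity = mp_async_allreduce and mp_configs.mp_skip_c_identity
    mp_fused_linear_param_grad_add = (
        mp_async_allreduce and mp_configs.mp_fused_linear_param_grad_add
    )
    return mp_async_allreduce, mp_skip_c_identity, mp_fused_linear_param_grad_add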
...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import _legacy_C_ops
from paddle.distributed import collective
...@@ -24,40 +22,8 @@ from paddle.nn.utils import dygraph_utils
from ....communication.reduce import ReduceOp, _get_reduce_op
_first_get_mp_env_flag = True
def _get_mp_env_flag(flag):
global _first_get_mp_env_flag
if _first_get_mp_env_flag:
print(
"Flags_mp_aysnc_allreduce is {}, which is used to support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear.".format(
str(os.getenv("Flags_mp_aysnc_allreduce")).lower()
)
)
print(
"Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
str(os.getenv("Flags_fused_linear_param_grad_add")).lower()
)
)
print(
"Flags_skip_mp_c_identity is {}, which is used to support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
str(os.getenv("Flags_skip_mp_c_identity")).lower()
)
)
# Model parallel environment flag.
# Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
# Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
# Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
assert flag in [
"Flags_mp_aysnc_allreduce",
"Flags_fused_linear_param_grad_add",
"Flags_skip_mp_c_identity",
], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"
return str(os.getenv(flag)).lower() in ["true", "1"]
def _c_identity(tensor, group=None):
def _c_identity(tensor, group=None, skip_c_identity_dynamic=False):
"""
Return a copy of the tensor, mainly used with model parallel.
...@@ -79,9 +45,7 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
if skip_c_identity_dynamic:
return tensor
else:
return _legacy_C_ops.c_identity(
...@@ -260,6 +224,7 @@ def _mp_allreduce(
group=None,
use_calc_stream=True,
use_model_parallel=True,
skip_c_identity_dynamic=False,
):
"""[it is same as allreduce above, but it supports model parallel. And it support inplace startegy]"""
if group is not None and not group.is_member():
...@@ -295,9 +260,7 @@ def _mp_allreduce(
@staticmethod
def backward(ctx, dy):
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
if skip_c_identity_dynamic:
return dy
else:
return _legacy_C_ops.c_identity(
...
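With _get_mp_env_flag removed, the skip behaviour is now passed explicitly into the communication helpers. A rough illustration of the new call shape, assuming a job already launched with fleet.init and a model-parallel group; the inputs are placeholders and these underscore-prefixed helpers are internal, not public API:

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu import mp_ops

hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
x = paddle.randn([4, 8])  # placeholder activation

# Forward identity (the c_identity op is skipped when requested); backward all-reduces dx.
y = mp_ops._c_identity(x, group=group, skip_c_identity_dynamic=True)

# Forward all-reduce; its backward c_identity is skipped the same way.
z = mp_ops._mp_allreduce(
    x,
    group=group,
    use_calc_stream=True,
    use_model_parallel=True,
    skip_c_identity_dynamic=True,
)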