Unverified commit a9f877ff, authored by Yuang Liu, committed by GitHub

[sharding stage 1 optim] Sharding comm overlap with backward (#55598)

Parent b10b899c
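This commit adds two fields to `DygraphShardingConfig` (`accumulate_steps` and `comm_overlap`) so that sharding stage 1 can overlap its gradient reduce with backward. The sketch below shows how a user would be expected to flip the switch; it mirrors the updated unit test at the end of this diff, and the parallel-degree values are illustrative rather than required by the commit.

```python
# Minimal sketch, assuming a multi-GPU launch (e.g. via paddle.distributed.launch).
# Only the sharding_configs lines are taken from this commit's unit test; the
# degree values and the surrounding setup are illustrative.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}

strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
# New in this commit: overlap the sharding gradient reduce with backward.
strategy.hybrid_configs["sharding_configs"].comm_overlap = True
strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1

fleet.init(is_collective=True, strategy=strategy)
# Build the model and optimizer as usual, then wrap them with
# fleet.distributed_model(...) and fleet.distributed_optimizer(...).
```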
...@@ -68,6 +68,8 @@ message PpConfig {
message DygraphShardingConfig {
  optional bool tensor_fusion = 1 [ default = false ];
+  optional int32 accumulate_steps = 2 [ default = 1 ];
+  optional bool comm_overlap = 3 [ default = false ];
}
message HybridConfig {
...
...@@ -78,12 +78,23 @@ class DygraphShardingOptimizer:
        self.tensor_fusion = strategy.hybrid_configs[
            'sharding_configs'
        ].tensor_fusion
+        self.accumulate_steps = strategy.hybrid_configs[
+            'sharding_configs'
+        ].accumulate_steps
+        self.comm_overlap = strategy.hybrid_configs[
+            'sharding_configs'
+        ].comm_overlap
        pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
-        if self.tensor_fusion:
+        if self.tensor_fusion or self.comm_overlap:
            assert (
                not pp_overlap
            ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."
+        self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
+        self._rank2decay = {}
+        self._rank2fused = {}
+        self._comm_buffers = []
        self._rank2params = self._partition_parameters()
        self._param2rank = self._map_param_to_rank()
...@@ -95,25 +106,22 @@ class DygraphShardingOptimizer:
                '_param_groups', self._rank2params[self._sharding_rank]
            )
        else:
-            self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
-            self._rank2decay = {}
-            self._rank2fused = {}
            self._tensor_fusion()
            decay_params = [
                p.name for p in self._rank2decay[self._sharding_rank]
            ]
-            all_params = self._rank2fused[self._sharding_rank]
+            fused_params = self._rank2fused[self._sharding_rank]
            apply_decay_param_fun = lambda x: x in decay_params
-            params = []
+            all_fused_params = []
            for v in self._rank2fused.values():
-                params += v
+                all_fused_params += v
-            self._parameter_list = params
+            self._parameter_list = all_fused_params
-            self._param_groups = params
+            self._param_groups = all_fused_params
-            self._set_inner_opt_attr('_parameter_list', all_params)
+            self._set_inner_opt_attr('_parameter_list', fused_params)
-            self._set_inner_opt_attr('_param_groups', all_params)
+            self._set_inner_opt_attr('_param_groups', fused_params)
            origin_decay_param_fun = getattr(
                self._inner_opt, '_apply_decay_param_fun', None
            )
...@@ -145,11 +153,23 @@ class DygraphShardingOptimizer:
            p.clear_gradient(set_to_zero)
    def _tensor_fusion(self):
+        comm_group = self._hcg.get_sharding_parallel_group()
        for i in range(self._sharding_world_size):
            params = self._rank2params[i]
-            decay_fused, all_fused = fused_parameters(
-                params, self._use_main_grad
+            dst = comm_group.ranks[i]
+            # TODO(sharding dev): make scale_after_comm a field to be configured by user
+            decay_fused, all_fused, all_buffer = fused_parameters(
+                params,
+                use_main_grad=self._use_main_grad,
+                fuse_param=True,
+                comm_overlap=self.comm_overlap,
+                comm_group=comm_group,
+                dst=dst,
+                acc_step=self.accumulate_steps,
+                scale_after_comm=False,
            )
+            if self.comm_overlap:
+                self._comm_buffers += all_buffer
            self._rank2decay[i] = decay_fused
            self._rank2fused[i] = all_fused
            for p in all_fused:
...@@ -199,6 +219,10 @@ class DygraphShardingOptimizer:
    def reduce_gradients(self, parameter_list, hcg):
        # TODO merge grad / nrank with dp
        logger.debug("sharding start gradients sync")
+        if self.comm_overlap:
+            for buffer in self._comm_buffers:
+                buffer.scale_grads()
+            return
        with framework.no_grad():
            sharding_nrank = hcg.get_sharding_parallel_group().nranks
            for param in parameter_list:
...
...@@ -37,11 +37,11 @@ else:
    from .pp_utils import p2p_communication as p2p
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
+    HOOK_ACTION,
+    FusedCommBuffer,
    assign_group_by_size,
)
-from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
__all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
...@@ -334,9 +334,11 @@ class PipelineParallel(MetaParallelBase):
        for dst in fused_parameter_group:
            parameter_list = fused_parameter_group[dst]
-            if not dp:
+            if act != HOOK_ACTION.ALL_REDUCE:
                # parse the relative dst rank to absolute dst rank for sharding
                dst = comm_group.ranks[dst]
+            else:
+                dst = -1
            var_groups = assign_group_by_size(parameter_list)
            for group_idx, parameters in var_groups.items():
                buffer = FusedCommBuffer(
...@@ -515,7 +517,7 @@
        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
...@@ -1256,7 +1258,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
-                buffer.scale_and_split_grads()
+                buffer.scale_grads()
        if static_scheduler:
            self._reset_counter()
...
...@@ -15,19 +15,10 @@
import paddle
from paddle import _legacy_C_ops
-from paddle.distributed.fleet.utils.tensor_fusion_helper import (
-    flatten_dense_tensors,
-)
-from paddle.framework import base as imperative_base
__all__ = []
-class HOOK_ACTION:
-    ALL_REDUCE = 0
-    REDUCE = 1
FLOAT_TYPE_DICT = {
    paddle.float16: "float16",
    paddle.float32: "float32",
...@@ -116,118 +107,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
        'nranks',
        nranks,
    )
-class FusedCommBuffer:
-    def __init__(self, id, params, comm_group, acc_steps=1, act=None, dst=-1):
-        self._id = id
-        self._params = params
-        self._acc_steps = acc_steps
-        self._comm_group = comm_group
-        self.use_main_grad = hasattr(self._params[0], "main_grad")
-        self._task = None
-        self._params_step_dict = {}
-        self._params_checked_in = 0
-        self._params_to_addr = {}
-        self._act = act
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            assert dst == -1
-        elif self._act == HOOK_ACTION.REDUCE:
-            assert dst != -1
-        else:
-            raise ValueError(
-                "The act should be allreudce for dp or reduce for sharding."
-            )
-        self._dst = dst
-        self._init_step_dict()
-        self.grad_storage = flatten_dense_tensors(
-            self._params,
-            use_main_grad=self.use_main_grad,
-            fuse_param=False,
-            warp_buffer=False,
-        ).buffer
-        self._record_addr()
-    def _record_addr(self):
-        for param in self._params:
-            addr = (
-                param.main_grad.data_ptr()
-                if self.use_main_grad
-                else param.grad.data_ptr()
-            )
-            self._params_to_addr[param.name] = addr
-    def _init_step_dict(self):
-        for p in self._params:
-            self._params_step_dict[p.name] = 0
-    def _reset_params_checked_in(self):
-        self._task = None
-        self._init_step_dict()
-        self._params_checked_in = 0
-    @property
-    def _all_params_checked_in(self):
-        return (
-            len(self._params) == self._params_checked_in
-            and len(self._params_step_dict) == 0
-        )
-    def add_grad(self, param):
-        assert param.name in self._params_step_dict
-        current_ptr = (
-            param.main_grad.data_ptr()
-            if self.use_main_grad
-            else param.grad.data_ptr()
-        )
-        if self._params_to_addr[param.name] != current_ptr:
-            raise ValueError(
-                "The address of the grad/main_grad of the param has been changed during training, "
-                "which is not allowed for dp/sharding overlap with pp. "
-                "This may be caused by some non-inplace operations on the grad/main_grad. "
-                "Please use the inplace version of the operations or disable the overlapping."
-            )
-        self._params_step_dict[param.name] += 1
-        if self._params_step_dict[param.name] == self._acc_steps:
-            self._params_checked_in += 1
-            self._params_step_dict.pop(param.name)
-        if self._all_params_checked_in:
-            self._comm_grads()
-    @imperative_base.no_grad
-    def _comm_grads(self):
-        assert self._all_params_checked_in
-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            task = paddle.distributed.all_reduce(
-                self.grad_storage, group=self._comm_group, sync_op=False
-            )
-        elif self._act == HOOK_ACTION.REDUCE:
-            task = paddle.distributed.reduce(
-                self.grad_storage,
-                dst=self._dst,
-                group=self._comm_group,
-                sync_op=False,
-            )
-        self._task = task
-    @imperative_base.no_grad
-    def scale_and_split_grads(self):
-        assert self._task is not None
-        self._task.wait()
-        scale_factor = 1.0 / self._comm_group.nranks
-        self.grad_storage.scale_(scale_factor)
-        self._reset_params_checked_in()
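With the hunk above, `HOOK_ACTION` and `FusedCommBuffer` no longer live in `pp_utils/utils.py`; they move into `tensor_fusion_helper.py` (next file). Code that imported them from the old location switches paths, as the updated test at the end of this diff does. Both import paths below are taken verbatim from this diff:

```python
# Before this commit:
# from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
#     HOOK_ACTION,
#     FusedCommBuffer,
# )

# After this commit:
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
    HOOK_ACTION,
    FusedCommBuffer,
)
```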
...@@ -12,13 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
+import os
from collections import OrderedDict
import numpy as np
import paddle
+from paddle.framework import base as imperative_base
from paddle.framework import core
+class HOOK_ACTION:
+    ALL_REDUCE = 0
+    REDUCE = 1
alignment = {
    "gpu": 256,
}
...@@ -101,23 +109,204 @@ def flatten_dense_tensors(
    return grad_storage
-def obtain_storage(parameters, use_main_grad, clip, dist):
+def bw_hook_func(buffer, param):
+    @paddle.autograd.no_grad()
+    def fused_comm(*_):
+        buffer.add_grad(param)
+    return fused_comm
+class FusedCommBuffer:
+    def __init__(
+        self,
+        id,
+        params,
+        comm_group,
+        acc_steps=1,
+        act=None,
+        dst=-1,
+        use_main_grad=None,
+        fuse_param=False,
+        scale_after_comm=True,
+    ):
+        self._id = id
+        self._params = params
+        self._acc_steps = acc_steps
+        self._comm_group = comm_group
+        self._scale_after_comm = scale_after_comm
+        self._fuse_param = fuse_param
+        self.use_main_grad = (
+            use_main_grad
+            if use_main_grad is not None
+            else hasattr(self._params[0], "main_grad")
+        )
+        self._task = None
+        self._params_step_dict = {}
+        self._params_checked_in = 0
+        self._grads_to_addr = {}
+        self._act = act
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            assert dst == -1
+        elif self._act == HOOK_ACTION.REDUCE:
+            assert dst != -1
+        else:
+            raise ValueError(
+                "The act should be allreudce for dp or reduce for sharding."
+            )
+        self._dst = dst
+        self._init_step_dict()
+        if self._fuse_param:
+            self.param_storage, self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=use_main_grad,
+                fuse_param=True,
+                warp_buffer=True,
+            )
+            self.param_storage = self.param_storage.buffer
+            self.grad_storage = self.grad_storage.buffer
+        else:
+            self.param_storage = None
+            self.grad_storage = flatten_dense_tensors(
+                self._params,
+                use_main_grad=self.use_main_grad,
+                fuse_param=False,
+                warp_buffer=False,
+            ).buffer
+        self._record_addr()
+    def _record_addr(self):
+        for param in self._params:
+            addr = (
+                param.main_grad.data_ptr()
+                if self.use_main_grad
+                else param.grad.data_ptr()
+            )
+            self._grads_to_addr[param.name] = addr
+    def _init_step_dict(self):
+        for p in self._params:
+            self._params_step_dict[p.name] = 0
+    def _reset_params_checked_in(self):
+        self._task = None
+        self._init_step_dict()
+        self._params_checked_in = 0
+    @property
+    def _all_params_checked_in(self):
+        return (
+            len(self._params) == self._params_checked_in
+            and len(self._params_step_dict) == 0
+        )
+    def add_grad(self, param):
+        assert param.name in self._params_step_dict
+        current_ptr = (
+            param.main_grad.data_ptr()
+            if self.use_main_grad
+            else param.grad.data_ptr()
+        )
+        if self._grads_to_addr[param.name] != current_ptr:
+            raise ValueError(
+                "The address of the grad/main_grad of the param has been changed during training, "
+                "which is not allowed for dp/sharding overlap with pp. "
+                "This may be caused by some non-inplace operations on the grad/main_grad. "
+                "Please use the inplace version of the operations or disable the overlapping."
+            )
+        self._params_step_dict[param.name] += 1
+        if self._params_step_dict[param.name] == self._acc_steps:
+            self._params_checked_in += 1
+            self._params_step_dict.pop(param.name)
+        if self._all_params_checked_in:
+            self._comm_grads()
+    @imperative_base.no_grad
+    def _comm_grads(self):
+        assert self._all_params_checked_in
+        if not self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+        if self._act == HOOK_ACTION.ALL_REDUCE:
+            task = paddle.distributed.all_reduce(
+                self.grad_storage, group=self._comm_group, sync_op=False
+            )
+        elif self._act == HOOK_ACTION.REDUCE:
+            task = paddle.distributed.reduce(
+                self.grad_storage,
+                dst=self._dst,
+                group=self._comm_group,
+                sync_op=False,
+            )
+        self._task = task
+    @imperative_base.no_grad
+    def scale_grads(self):
+        assert self._task is not None
+        self._task.wait()
+        if self._scale_after_comm:
+            scale_factor = 1.0 / self._comm_group.nranks
+            self.grad_storage.scale_(scale_factor)
+        self._reset_params_checked_in()
+def obtain_storage(
+    parameters,
+    use_main_grad=False,
+    clip=True,
+    dist=False,
+    fuse_param=True,
+    comm_overlap=False,
+    act=None,
+    comm_group=None,
+    dst=-1,
+    acc_steps=1,
+    scale_after_comm=False,
+):
    if len(parameters) < 1:
-        return []
+        return [], []
    var_groups = assign_group_by_size(parameters, group_size=256 * 1024 * 1024)
    storage = []
+    buffers = []
    for group_idx, parameters in var_groups.items():
-        param_storage, grad_storage = flatten_dense_tensors(
+        comm_buffer = FusedCommBuffer(
+            group_idx,
            parameters,
+            comm_group=comm_group,
+            acc_steps=acc_steps,
+            act=act,
+            dst=dst,
            use_main_grad=use_main_grad,
-            fuse_param=True,
+            fuse_param=fuse_param,
-            warp_buffer=True,
+            scale_after_comm=scale_after_comm,
        )
-        param_storage.buffer.need_clip = clip
-        param_storage.buffer.is_distributed = dist
-        storage.append(param_storage.buffer)
-    return storage
+        if fuse_param:
+            param_buffer = comm_buffer.param_storage
+            param_buffer.need_clip = clip
+            param_buffer.is_distributed = dist
+            storage.append(param_buffer)
+        if comm_overlap:
+            for param in parameters:
+                param._register_backward_hook(bw_hook_func(comm_buffer, param))
+        buffers.append(comm_buffer)
+    return storage, buffers
def filter_params(params, is_fp32, is_distributed, need_clip):
...@@ -155,7 +344,38 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
    return params, dtype
-def fused_parameters(parameters, use_main_grad):
+def fused_parameters(
+    parameters,
+    use_main_grad=False,
+    fuse_param=True,
+    comm_overlap=False,
+    comm_group=None,
+    dst=-1,
+    acc_step=1,
+    scale_after_comm=False,
+):
+    """
+    Fuse gradients, optionally fuse parameters, and prepare comm buffers for
+    comm overlap when it is enabled.
+    :param parameters: all parameters to be fused.
+    :param use_main_grad: whether the gradients use main_grad.
+    :param comm_overlap: whether to enable comm overlap.
+    :param comm_group: the communication group used for comm overlap.
+    :param dst: the destination rank used for comm overlap.
+    :param acc_step: number of gradient accumulation steps, used for comm overlap.
+    :param fuse_param: whether to fuse the parameters as well as the gradients.
+    :param scale_after_comm: when comm overlap is enabled, whether grads are scaled after the communication.
+    :return: fused decay parameters, all fused parameters, and the comm buffers created for them.
+    """
+    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
+    act = (
+        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
+    )
+    if comm_overlap:
+        assert comm_group is not None
+    if act == HOOK_ACTION.REDUCE:
+        assert dst != -1
+    elif act == HOOK_ACTION.ALL_REDUCE:
+        dst = -1
    param_groups = []
    attrs = []
...@@ -178,6 +398,7 @@ def fused_parameters(parameters, use_main_grad):
    decay_fused = []
    all_fused = []
+    all_buffers = []
    for params, attr in zip(param_groups, attrs):
        decay_params = []
        other_params = []
...@@ -190,14 +411,36 @@
        is_distributed = attr[1]
        need_clip = attr[2]
-        decay = obtain_storage(
-            decay_params, use_main_grad, need_clip, is_distributed
+        decay, decay_buffers = obtain_storage(
+            decay_params,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
        )
-        other = obtain_storage(
-            other_params, use_main_grad, need_clip, is_distributed
+        other, other_buffers = obtain_storage(
+            other_params,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            use_main_grad=use_main_grad,
+            clip=need_clip,
+            dist=is_distributed,
+            act=act,
+            comm_group=comm_group,
+            dst=dst,
+            acc_steps=acc_step,
+            scale_after_comm=scale_after_comm,
        )
        decay_fused += decay
        all_fused += decay
        all_fused += other
+        all_buffers += decay_buffers
+        all_buffers += other_buffers
-    return decay_fused, all_fused
+    return decay_fused, all_fused, all_buffers
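For reference, here is a hedged sketch of how the extended `fused_parameters` signature is consumed, modeled on `DygraphShardingOptimizer._tensor_fusion` earlier in this diff. `params`, `hcg`, and `accumulate_steps` are placeholders supplied by the caller, not names defined by this commit; the buffers only need `scale_grads()` once backward has finished.

```python
# Sketch only: mirrors the optimizer-side call in this diff. `params`, `hcg`
# and `accumulate_steps` are placeholders provided by the surrounding code.
from paddle.distributed.fleet.utils.tensor_fusion_helper import fused_parameters

comm_group = hcg.get_sharding_parallel_group()
dst = comm_group.ranks[0]  # destination rank when HOOK_ACTION.REDUCE is used

decay_fused, all_fused, all_buffers = fused_parameters(
    params,
    use_main_grad=hasattr(params[0], "main_grad"),
    fuse_param=True,
    comm_overlap=True,
    comm_group=comm_group,
    dst=dst,
    acc_step=accumulate_steps,
    scale_after_comm=False,
)

# After backward has run (the buffers launch their async reduce from the
# registered backward hooks), wait for the communication and scale:
for buffer in all_buffers:
    buffer.scale_grads()
```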
...@@ -99,6 +99,8 @@ class TestDistSharding(unittest.TestCase):
            "pp_degree": 1,
        }
        self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
+        self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
+        self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
        fleet.init(is_collective=True, strategy=self.strategy)
        self.data = np.random.randint(
            0,
...
...@@ -15,7 +15,7 @@
import unittest
import paddle
-from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
+from paddle.distributed.fleet.utils.tensor_fusion_helper import (
    HOOK_ACTION,
    FusedCommBuffer,
)
...