Unverified · Commit a9f877ff authored by: Y Yuang Liu, committed by: GitHub

[sharding stage 1 optim] Sharding comm overlap with backward (#55598)

Parent b10b899c
......@@ -68,6 +68,8 @@ message PpConfig {
message DygraphShardingConfig {
optional bool tensor_fusion = 1 [ default = false ];
optional int32 accumulate_steps = 2 [ default = 1 ];
optional bool comm_overlap = 3 [ default = false ];
}
message HybridConfig {
......
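For reference, a minimal sketch of how the new switch would be turned on from user code; it follows the strategy fields exercised by the test case at the end of this diff, and the degree values are placeholders:

# Hedged example: enabling sharding tensor fusion plus comm overlap.
# Degrees below are illustrative; only the sharding_configs fields are
# the subject of this commit.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}
strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
strategy.hybrid_configs["sharding_configs"].comm_overlap = True
strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
fleet.init(is_collective=True, strategy=strategy)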
......@@ -78,12 +78,23 @@ class DygraphShardingOptimizer:
self.tensor_fusion = strategy.hybrid_configs[
'sharding_configs'
].tensor_fusion
self.accumulate_steps = strategy.hybrid_configs[
'sharding_configs'
].accumulate_steps
self.comm_overlap = strategy.hybrid_configs[
'sharding_configs'
].comm_overlap
pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
if self.tensor_fusion:
if self.tensor_fusion or self.comm_overlap:
assert (
not pp_overlap
), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."
self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
self._rank2decay = {}
self._rank2fused = {}
self._comm_buffers = []
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
......@@ -95,25 +106,22 @@ class DygraphShardingOptimizer:
'_param_groups', self._rank2params[self._sharding_rank]
)
else:
self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
self._rank2decay = {}
self._rank2fused = {}
self._tensor_fusion()
decay_params = [
p.name for p in self._rank2decay[self._sharding_rank]
]
all_params = self._rank2fused[self._sharding_rank]
fused_params = self._rank2fused[self._sharding_rank]
apply_decay_param_fun = lambda x: x in decay_params
params = []
all_fused_params = []
for v in self._rank2fused.values():
params += v
self._parameter_list = params
self._param_groups = params
all_fused_params += v
self._parameter_list = all_fused_params
self._param_groups = all_fused_params
self._set_inner_opt_attr('_parameter_list', all_params)
self._set_inner_opt_attr('_param_groups', all_params)
self._set_inner_opt_attr('_parameter_list', fused_params)
self._set_inner_opt_attr('_param_groups', fused_params)
origin_decay_param_fun = getattr(
self._inner_opt, '_apply_decay_param_fun', None
)
......@@ -145,11 +153,23 @@ class DygraphShardingOptimizer:
p.clear_gradient(set_to_zero)
def _tensor_fusion(self):
comm_group = self._hcg.get_sharding_parallel_group()
for i in range(self._sharding_world_size):
params = self._rank2params[i]
decay_fused, all_fused = fused_parameters(
params, self._use_main_grad
dst = comm_group.ranks[i]
# TODO(sharding dev): make scale_after_comm a field to be configured by user
decay_fused, all_fused, all_buffer = fused_parameters(
params,
use_main_grad=self._use_main_grad,
fuse_param=True,
comm_overlap=self.comm_overlap,
comm_group=comm_group,
dst=dst,
acc_step=self.accumulate_steps,
scale_after_comm=False,
)
if self.comm_overlap:
self._comm_buffers += all_buffer
self._rank2decay[i] = decay_fused
self._rank2fused[i] = all_fused
for p in all_fused:
......@@ -199,6 +219,10 @@ class DygraphShardingOptimizer:
def reduce_gradients(self, parameter_list, hcg):
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
if self.comm_overlap:
for buffer in self._comm_buffers:
buffer.scale_grads()
return
with framework.no_grad():
sharding_nrank = hcg.get_sharding_parallel_group().nranks
for param in parameter_list:
......
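With comm_overlap enabled, reduce_gradients above no longer launches any collective itself: the reduce/all_reduce calls were already issued asynchronously from the backward hooks, so it only waits on each fused buffer and applies the gradient scaling. A hedged sketch of the resulting step flow, assuming the standard fleet dygraph wrapping (the loop itself is hypothetical):

# model and optimizer wrapped the usual way for hybrid parallel dygraph.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

for batch in dataloader:
    loss = model(batch)
    loss.backward()        # hooks call buffer.add_grad(); a full buffer kicks
                           # off an async reduce on its fused grad storage
    optimizer.step()       # reduce_gradients() only calls buffer.scale_grads(),
                           # which waits on the task and scales the storage
    optimizer.clear_grad()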
......@@ -37,11 +37,11 @@ else:
from .pp_utils import p2p_communication as p2p
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
HOOK_ACTION,
FusedCommBuffer,
assign_group_by_size,
)
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer
__all__ = []
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
......@@ -334,9 +334,11 @@ class PipelineParallel(MetaParallelBase):
for dst in fused_parameter_group:
parameter_list = fused_parameter_group[dst]
if not dp:
if act != HOOK_ACTION.ALL_REDUCE:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]
else:
dst = -1
var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
......@@ -515,7 +517,7 @@ class PipelineParallel(MetaParallelBase):
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
buffer.scale_grads()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
......@@ -1256,7 +1258,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()
buffer.scale_grads()
if static_scheduler:
self._reset_counter()
......
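The dst handling in the hunk above translates a rank index that is relative to the communication group into the global rank required by reduce, and falls back to -1 for all_reduce, where no single destination exists. A small illustration with a made-up group layout:

# Hypothetical sharding group built from global ranks [4, 5, 6, 7].
# A buffer owned by index 2 inside the group must be reduced to global
# rank 6, so the relative index is translated first (comm_group.ranks[dst]).
group_ranks = [4, 5, 6, 7]
relative_dst = 2
absolute_dst = group_ranks[relative_dst]   # -> 6
all_reduce_dst = -1                        # all_reduce has no destination rank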
......@@ -15,19 +15,10 @@
import paddle
from paddle import _legacy_C_ops
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
flatten_dense_tensors,
)
from paddle.framework import base as imperative_base
__all__ = []
class HOOK_ACTION:
ALL_REDUCE = 0
REDUCE = 1
FLOAT_TYPE_DICT = {
paddle.float16: "float16",
paddle.float32: "float32",
......@@ -116,118 +107,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
'nranks',
nranks,
)
class FusedCommBuffer:
def __init__(self, id, params, comm_group, acc_steps=1, act=None, dst=-1):
self._id = id
self._params = params
self._acc_steps = acc_steps
self._comm_group = comm_group
self.use_main_grad = hasattr(self._params[0], "main_grad")
self._task = None
self._params_step_dict = {}
self._params_checked_in = 0
self._params_to_addr = {}
self._act = act
if self._act == HOOK_ACTION.ALL_REDUCE:
assert dst == -1
elif self._act == HOOK_ACTION.REDUCE:
assert dst != -1
else:
raise ValueError(
"The act should be allreudce for dp or reduce for sharding."
)
self._dst = dst
self._init_step_dict()
self.grad_storage = flatten_dense_tensors(
self._params,
use_main_grad=self.use_main_grad,
fuse_param=False,
warp_buffer=False,
).buffer
self._record_addr()
def _record_addr(self):
for param in self._params:
addr = (
param.main_grad.data_ptr()
if self.use_main_grad
else param.grad.data_ptr()
)
self._params_to_addr[param.name] = addr
def _init_step_dict(self):
for p in self._params:
self._params_step_dict[p.name] = 0
def _reset_params_checked_in(self):
self._task = None
self._init_step_dict()
self._params_checked_in = 0
@property
def _all_params_checked_in(self):
return (
len(self._params) == self._params_checked_in
and len(self._params_step_dict) == 0
)
def add_grad(self, param):
assert param.name in self._params_step_dict
current_ptr = (
param.main_grad.data_ptr()
if self.use_main_grad
else param.grad.data_ptr()
)
if self._params_to_addr[param.name] != current_ptr:
raise ValueError(
"The address of the grad/main_grad of the param has been changed during training, "
"which is not allowed for dp/sharding overlap with pp. "
"This may be caused by some non-inplace operations on the grad/main_grad. "
"Please use the inplace version of the operations or disable the overlapping."
)
self._params_step_dict[param.name] += 1
if self._params_step_dict[param.name] == self._acc_steps:
self._params_checked_in += 1
self._params_step_dict.pop(param.name)
if self._all_params_checked_in:
self._comm_grads()
@imperative_base.no_grad
def _comm_grads(self):
assert self._all_params_checked_in
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)
elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
self.grad_storage,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)
self._task = task
@imperative_base.no_grad
def scale_and_split_grads(self):
assert self._task is not None
self._task.wait()
scale_factor = 1.0 / self._comm_group.nranks
self.grad_storage.scale_(scale_factor)
self._reset_params_checked_in()
......@@ -12,13 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
from collections import OrderedDict
import numpy as np
import paddle
from paddle.framework import base as imperative_base
from paddle.framework import core
class HOOK_ACTION:
ALL_REDUCE = 0
REDUCE = 1
alignment = {
"gpu": 256,
}
......@@ -101,23 +109,204 @@ def flatten_dense_tensors(
return grad_storage
def obtain_storage(parameters, use_main_grad, clip, dist):
def bw_hook_func(buffer, param):
@paddle.autograd.no_grad()
def fused_comm(*_):
buffer.add_grad(param)
return fused_comm
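bw_hook_func is a small closure factory: each returned fused_comm hook is bound to one (buffer, param) pair and checks that parameter's gradient into the buffer when autograd finishes computing it. The registration itself happens in obtain_storage below; in condensed form:

# Condensed from the comm_overlap branch of obtain_storage further down.
for param in parameters:
    hook = bw_hook_func(comm_buffer, param)    # closure over this pair
    param._register_backward_hook(hook)        # fires once param's grad is ready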
class FusedCommBuffer:
def __init__(
self,
id,
params,
comm_group,
acc_steps=1,
act=None,
dst=-1,
use_main_grad=None,
fuse_param=False,
scale_after_comm=True,
):
self._id = id
self._params = params
self._acc_steps = acc_steps
self._comm_group = comm_group
self._scale_after_comm = scale_after_comm
self._fuse_param = fuse_param
self.use_main_grad = (
use_main_grad
if use_main_grad is not None
else hasattr(self._params[0], "main_grad")
)
self._task = None
self._params_step_dict = {}
self._params_checked_in = 0
self._grads_to_addr = {}
self._act = act
if self._act == HOOK_ACTION.ALL_REDUCE:
assert dst == -1
elif self._act == HOOK_ACTION.REDUCE:
assert dst != -1
else:
raise ValueError(
"The act should be allreudce for dp or reduce for sharding."
)
self._dst = dst
self._init_step_dict()
if self._fuse_param:
self.param_storage, self.grad_storage = flatten_dense_tensors(
self._params,
use_main_grad=use_main_grad,
fuse_param=True,
warp_buffer=True,
)
self.param_storage = self.param_storage.buffer
self.grad_storage = self.grad_storage.buffer
else:
self.param_storage = None
self.grad_storage = flatten_dense_tensors(
self._params,
use_main_grad=self.use_main_grad,
fuse_param=False,
warp_buffer=False,
).buffer
self._record_addr()
def _record_addr(self):
for param in self._params:
addr = (
param.main_grad.data_ptr()
if self.use_main_grad
else param.grad.data_ptr()
)
self._grads_to_addr[param.name] = addr
def _init_step_dict(self):
for p in self._params:
self._params_step_dict[p.name] = 0
def _reset_params_checked_in(self):
self._task = None
self._init_step_dict()
self._params_checked_in = 0
@property
def _all_params_checked_in(self):
return (
len(self._params) == self._params_checked_in
and len(self._params_step_dict) == 0
)
def add_grad(self, param):
assert param.name in self._params_step_dict
current_ptr = (
param.main_grad.data_ptr()
if self.use_main_grad
else param.grad.data_ptr()
)
if self._grads_to_addr[param.name] != current_ptr:
raise ValueError(
"The address of the grad/main_grad of the param has been changed during training, "
"which is not allowed for dp/sharding overlap with pp. "
"This may be caused by some non-inplace operations on the grad/main_grad. "
"Please use the inplace version of the operations or disable the overlapping."
)
self._params_step_dict[param.name] += 1
if self._params_step_dict[param.name] == self._acc_steps:
self._params_checked_in += 1
self._params_step_dict.pop(param.name)
if self._all_params_checked_in:
self._comm_grads()
@imperative_base.no_grad
def _comm_grads(self):
assert self._all_params_checked_in
if not self._scale_after_comm:
scale_factor = 1.0 / self._comm_group.nranks
self.grad_storage.scale_(scale_factor)
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)
elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
self.grad_storage,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)
self._task = task
@imperative_base.no_grad
def scale_grads(self):
assert self._task is not None
self._task.wait()
if self._scale_after_comm:
scale_factor = 1.0 / self._comm_group.nranks
self.grad_storage.scale_(scale_factor)
self._reset_params_checked_in()
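Taken together, the class runs a simple per-buffer cycle: every parameter must check in acc_steps gradients, the last check-in triggers the asynchronous collective, and scale_grads later waits on it. A hedged walk-through with illustrative constructor arguments (the group and rank values are made up):

# Illustrative lifecycle of one FusedCommBuffer under sharding comm overlap.
buffer = FusedCommBuffer(
    id=0,
    params=params,                   # one size-grouped bucket of parameters
    comm_group=sharding_group,
    acc_steps=4,                     # gradient accumulation steps
    act=HOOK_ACTION.REDUCE,
    dst=sharding_group.ranks[0],     # absolute rank owning this shard
    fuse_param=True,
    scale_after_comm=False,          # scale before launching the reduce
)
# During backward, registered hooks call buffer.add_grad(param). Once every
# param has checked in acc_steps times, _comm_grads() launches an async
# reduce on buffer.grad_storage. After backward, the caller invokes
# buffer.scale_grads() to wait on the task (and scale, if deferred).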
def obtain_storage(
parameters,
use_main_grad=False,
clip=True,
dist=False,
fuse_param=True,
comm_overlap=False,
act=None,
comm_group=None,
dst=-1,
acc_steps=1,
scale_after_comm=False,
):
if len(parameters) < 1:
return []
return [], []
var_groups = assign_group_by_size(parameters, group_size=256 * 1024 * 1024)
storage = []
buffers = []
for group_idx, parameters in var_groups.items():
param_storage, grad_storage = flatten_dense_tensors(
comm_buffer = FusedCommBuffer(
group_idx,
parameters,
comm_group=comm_group,
acc_steps=acc_steps,
act=act,
dst=dst,
use_main_grad=use_main_grad,
fuse_param=True,
warp_buffer=True,
fuse_param=fuse_param,
scale_after_comm=scale_after_comm,
)
param_storage.buffer.need_clip = clip
param_storage.buffer.is_distributed = dist
storage.append(param_storage.buffer)
return storage
if fuse_param:
param_buffer = comm_buffer.param_storage
param_buffer.need_clip = clip
param_buffer.is_distributed = dist
storage.append(param_buffer)
if comm_overlap:
for param in parameters:
param._register_backward_hook(bw_hook_func(comm_buffer, param))
buffers.append(comm_buffer)
return storage, buffers
def filter_params(params, is_fp32, is_distributed, need_clip):
......@@ -155,7 +344,38 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
return params, dtype
def fused_parameters(parameters, use_main_grad):
def fused_parameters(
parameters,
use_main_grad=False,
fuse_param=True,
comm_overlap=False,
comm_group=None,
dst=-1,
acc_step=1,
scale_after_comm=False,
):
"""
Fuse gradients, fuse parameters if enabled, and prepare for comm overlap if enabled.
:param parameters: all parameters to be fused.
:param use_main_grad: whether the gradients use main_grad
:param fuse_param: whether to fuse the parameters as well
:param comm_overlap: whether to enable comm overlap
:param comm_group: the comm group used for comm overlap
:param dst: the destination rank used for comm overlap
:param acc_step: gradient accumulation steps, used for comm overlap
:param scale_after_comm: when comm overlap is enabled, whether grads are scaled after the communication instead of before it
:return: decay param storages, all param storages, and the comm buffers (non-empty only when comm overlap is enabled)
"""
g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
act = (
HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
)
if comm_overlap:
assert comm_group is not None
if act == HOOK_ACTION.REDUCE:
assert dst != -1
elif act == HOOK_ACTION.ALL_REDUCE:
dst = -1
param_groups = []
attrs = []
......@@ -178,6 +398,7 @@ def fused_parameters(parameters, use_main_grad):
decay_fused = []
all_fused = []
all_buffers = []
for params, attr in zip(param_groups, attrs):
decay_params = []
other_params = []
......@@ -190,14 +411,36 @@ def fused_parameters(parameters, use_main_grad):
is_distributed = attr[1]
need_clip = attr[2]
decay = obtain_storage(
decay_params, use_main_grad, need_clip, is_distributed
decay, decay_buffers = obtain_storage(
decay_params,
use_main_grad=use_main_grad,
clip=need_clip,
dist=is_distributed,
fuse_param=fuse_param,
comm_overlap=comm_overlap,
act=act,
comm_group=comm_group,
dst=dst,
acc_steps=acc_step,
scale_after_comm=scale_after_comm,
)
other = obtain_storage(
other_params, use_main_grad, need_clip, is_distributed
other, other_buffers = obtain_storage(
other_params,
fuse_param=fuse_param,
comm_overlap=comm_overlap,
use_main_grad=use_main_grad,
clip=need_clip,
dist=is_distributed,
act=act,
comm_group=comm_group,
dst=dst,
acc_steps=acc_step,
scale_after_comm=scale_after_comm,
)
decay_fused += decay
all_fused += decay
all_fused += other
all_buffers += decay_buffers
all_buffers += other_buffers
return decay_fused, all_fused
return decay_fused, all_fused, all_buffers
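The main caller of fused_parameters in this commit is DygraphShardingOptimizer._tensor_fusion (earlier in this diff); a condensed, hedged sketch of that call pattern, with placeholder names for the values the optimizer derives from its hybrid communication group:

# Condensed from _tensor_fusion above; sharding_group, rank_index,
# use_main_grad and accumulate_steps are placeholders.
decay_fused, all_fused, all_buffers = fused_parameters(
    params_of_one_rank,
    use_main_grad=use_main_grad,
    fuse_param=True,
    comm_overlap=True,
    comm_group=sharding_group,
    dst=sharding_group.ranks[rank_index],
    acc_step=accumulate_steps,
    scale_after_comm=False,
)
# decay_fused / all_fused replace the inner optimizer's parameter lists,
# while all_buffers is kept so reduce_gradients() can wait on and scale them.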
......@@ -99,6 +99,8 @@ class TestDistSharding(unittest.TestCase):
"pp_degree": 1,
}
self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
fleet.init(is_collective=True, strategy=self.strategy)
self.data = np.random.randint(
0,
......
......@@ -15,7 +15,7 @@
import unittest
import paddle
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import (
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
HOOK_ACTION,
FusedCommBuffer,
)
......