未验证 提交 4c4d3185 编写于 作者: Y Yuang Liu 提交者: GitHub

Sharding stage 1 tensor fusion (#55427)

上级 f7cbfc4c
......@@ -66,6 +66,10 @@ message PpConfig {
optional bool profiling = 5 [ default = false ];
}
message DygraphShardingConfig {
optional bool tensor_fusion = 1 [ default = false ];
}
message HybridConfig {
optional int32 dp_degree = 1 [ default = -1 ];
optional int32 mp_degree = 2 [ default = 1 ];
......@@ -73,6 +77,7 @@ message HybridConfig {
optional int32 sharding_degree = 4 [ default = 1 ];
optional MpConfig mp_configs = 5;
optional PpConfig pp_configs = 6;
optional DygraphShardingConfig sharding_configs = 7;
}
message AMPConfig {
......
......@@ -18,8 +18,10 @@ from functools import reduce
import paddle
from paddle import framework
from paddle.distributed import fleet
from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters
def _is_trainable(param):
......@@ -62,15 +64,53 @@ class DygraphShardingOptimizer:
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()
strategy = fleet.fleet._user_defined_strategy
self.tensor_fusion = strategy.hybrid_configs[
'sharding_configs'
].tensor_fusion
pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
if self.tensor_fusion:
assert (
not pp_overlap
), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
if not self.tensor_fusion:
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
else:
self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
self._rank2decay = {}
self._rank2fused = {}
self._tensor_fusion()
decay_params = [
p.name for p in self._rank2decay[self._sharding_rank]
]
all_params = self._rank2fused[self._sharding_rank]
apply_decay_param_fun = lambda x: x in decay_params
params = []
for v in self._rank2fused.values():
params += v
self._parameter_list = params
self._param_groups = params
self._set_inner_opt_attr('_parameter_list', all_params)
self._set_inner_opt_attr('_param_groups', all_params)
origin_decay_param_fun = getattr(
self._inner_opt, '_apply_decay_param_fun', None
)
if origin_decay_param_fun is not None:
self._set_inner_opt_attr(
'_apply_decay_param_fun', apply_decay_param_fun
)
def clear_grad(self, set_to_zero=True):
"""
......@@ -85,7 +125,25 @@ class DygraphShardingOptimizer:
p.main_grad._clear()
p.main_grad = None
elif not hasattr(p, "main_grad"):
p.clear_gradient(set_to_zero)
if self.tensor_fusion:
if set_to_zero:
p.grad.zero_()
else:
p.grad._clear()
p.grad = None
else:
p.clear_gradient(set_to_zero)
def _tensor_fusion(self):
for i in range(self._sharding_world_size):
params = self._rank2params[i]
decay_fused, all_fused = fused_parameters(
params, self._use_main_grad
)
self._rank2decay[i] = decay_fused
self._rank2fused[i] = all_fused
for p in all_fused:
self._param2rank[p.name] = i
def _partition_parameters(self):
"""
......@@ -167,7 +225,12 @@ class DygraphShardingOptimizer:
logger.debug("sharding start sync parameters")
with framework.no_grad():
# TODO detach not need (?)
for rank, params in self._rank2params.items():
valid_rank_to_params = (
self._rank2params
if not self.tensor_fusion
else self._rank2fused
)
for rank, params in valid_rank_to_params.items():
for param in params:
paddle.distributed.broadcast(
param,
......@@ -236,9 +299,12 @@ class DygraphShardingOptimizer:
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
update_param_names = [
p.name for p in self._rank2params[self._sharding_rank]
]
rank_params = (
self._rank2params[self._sharding_rank]
if not self.tensor_fusion
else self._rank2fused[self._sharding_rank]
)
update_param_names = [p.name for p in rank_params]
update_params_grads = [
(p, g) for p, g in params_grads if p.name in update_param_names
]
......
......@@ -30,6 +30,14 @@ from paddle.framework import core
from .group_sharded_utils import Type, cvt_to_device, device_guard
class BufferWarper(core.eager.Tensor):
def __init__(self):
super().__init__()
self.need_clip = True
self.is_distributed = False
self.trainable = True
class InternalStorage:
"""
This is a basic class, which is responsible for consolidating the basic storage tensor.
......@@ -97,6 +105,12 @@ class InternalStorage:
self.buffer = self.buffer.cast(dtype=dtype)
self._dtype = dtype
def warp_buffer(self):
tmp_buffer = BufferWarper()
self._buffer = self.buffer
tmp_buffer.get_tensor()._share_data_with(self.buffer.get_tensor())
self.buffer = tmp_buffer
class ParamStorage(InternalStorage):
"""
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections import OrderedDict
import numpy as np
import paddle
from paddle.framework import core
alignment = {
"gpu": 256,
}
align = {
paddle.float16.value: 2,
paddle.bfloat16.value: 2,
paddle.float32.value: 4,
}
def assign_group_by_size(parameters, group_size=256 * 1024 * 1024):
# TODO(Yuang Liu): make pp_utils/utils use this tensor fusion helper
is_sparse_gradient = [False] * len(parameters)
group_indices = core.eager_assign_group_by_size(
parameters, is_sparse_gradient, [group_size, group_size]
)
var_groups = OrderedDict()
for group_idx, indices in enumerate(group_indices):
for index in indices:
var_groups.setdefault(group_idx, []).append(parameters[index])
return var_groups
def flatten_dense_tensors(parameters, use_main_grad):
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import (
GradStorage,
ParamStorage,
)
_buffer_size = 0
_param2align = {}
dtype = parameters[0].dtype
for param in parameters:
assert param.trainable, "param must be trainable..."
size = np.prod(param.shape) * align[dtype]
remaining = size % alignment["gpu"]
ali = 0 if remaining == 0 else alignment["gpu"] - remaining
align_ = ali // align[dtype]
_buffer_size += np.prod(param.shape) + align_
_param2align[param.name] = align_
param_storage = ParamStorage(size=_buffer_size, dtype=dtype, device="gpu")
param_storage.add_rank_params(parameters, _param2align)
# process gradient
grad_dtype = paddle.float32 if use_main_grad else dtype
grad_storage = GradStorage(
size=_buffer_size,
dtype=grad_dtype,
device="gpu",
destination="0",
parm2align=_param2align,
)
for param in parameters:
grad_storage.add_grad(param, _param2align[param.name])
param_storage.warp_buffer()
grad_storage.warp_buffer()
if not use_main_grad:
# param_storage --> grad_storage
param_storage.buffer._copy_gradient_from(grad_storage.buffer)
else:
param_storage.buffer.main_grad = grad_storage.buffer
param_storage.buffer.stop_gradient = False
return param_storage, grad_storage
def obtain_storage(parameters, use_main_grad, clip, dist):
if len(parameters) < 1:
return []
var_groups = assign_group_by_size(parameters)
storage = []
for group_idx, parameters in var_groups.items():
param_storage, grad_storage = flatten_dense_tensors(
parameters, use_main_grad
)
param_storage.buffer.need_clip = clip
param_storage.buffer.is_distributed = dist
storage.append(param_storage.buffer)
return storage
def filter_params(params, is_fp32, is_distributed, need_clip):
params = list(
filter(
lambda x: x.is_distributed
if is_distributed
else (not x.is_distributed),
params,
)
)
params = list(
filter(
lambda x: getattr(x, 'need_clip', True)
if need_clip
else (not getattr(x, 'need_clip', True)),
params,
)
)
params = list(
filter(
lambda x: x.dtype == paddle.float32
if is_fp32
else x.dtype != paddle.float32,
params,
)
)
dtype = None
for p in params:
if dtype is None:
dtype = p.dtype
else:
assert dtype == p.dtype
return params, dtype
def fused_parameters(parameters, use_main_grad):
param_groups = []
attrs = []
is_fp32 = [True, False]
is_distributed = [True, False]
need_clip = [True, False]
no_fp32_dtype = None
for fp32, dist, clip in itertools.product(
is_fp32, is_distributed, need_clip
):
params, dtype = filter_params(parameters, fp32, dist, clip)
if not fp32:
if no_fp32_dtype is None:
no_fp32_dtype = dtype
elif dtype is not None:
assert no_fp32_dtype == dtype
attrs.append([dtype, dist, clip])
param_groups.append(params)
decay_fused = []
all_fused = []
for params, attr in zip(param_groups, attrs):
decay_params = []
other_params = []
for param in params:
if not any(nd in param.name for nd in ["bias", "norm", "b_0"]):
decay_params.append(param)
else:
other_params.append(param)
is_distributed = attr[1]
need_clip = attr[2]
decay = obtain_storage(
decay_params, use_main_grad, need_clip, is_distributed
)
other = obtain_storage(
other_params, use_main_grad, need_clip, is_distributed
)
decay_fused += decay
all_fused += decay
all_fused += other
return decay_fused, all_fused
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
class SimpleDPNet(paddle.nn.Layer):
def __init__(
self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
):
super().__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistSharding(unittest.TestCase):
def setUp(self):
random.seed(2021)
np.random.seed(2021)
paddle.seed(2021)
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
fleet.init(is_collective=True, strategy=self.strategy)
self.data = np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
if paddle.distributed.get_rank() == 0:
self.batch_sharding = paddle.to_tensor(self.data[:2])
else:
self.batch_sharding = paddle.to_tensor(self.data[2:])
self.batch_single = paddle.to_tensor(self.data)
def train_batch(self, batch, model, optimizer):
output = model(batch)
loss = output.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
return loss
def build_optimizer(self, model):
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.001,
grad_clip=clip,
)
return optimizer
def build_model_optimizer(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model_a = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
optimizer_a = self.build_optimizer(model_a)
model_b = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
optimizer_b = self.build_optimizer(model_b)
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
return model_a, optimizer_a, model_b, optimizer_b
def sharding_model(self):
(
model_a,
optimizer_a,
model_b,
optimizer_b,
) = self.build_model_optimizer()
self.assertTrue(
isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer)
)
for idx in range(STEPS):
loss_a = self.train_batch(self.batch_sharding, model_a, optimizer_a)
loss_b = self.train_batch(self.batch_single, model_b, optimizer_b)
np.testing.assert_allclose(loss_a, loss_b, rtol=1e-6, atol=1e-6)
for j in range(len(model_a.parameters())):
np.testing.assert_allclose(
model_a.parameters()[j].numpy(),
model_b.parameters()[j].numpy(),
rtol=1e-6,
atol=1e-7,
)
def test_sharding_adam(self):
self.sharding_model()
if __name__ == "__main__":
unittest.main()
......@@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_sharding_logic(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
def test_hybrid_parallel_sharding_tensor_fusion(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py')
def test_hybrid_parallel_sharding_state_dict(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册