Unverified commit 79b261ba, authored by wuhuachaocoding, committed by GitHub

solve share params bugs and add exclude_layer attr for stage3. (#48695)

Parent 3f2f036c
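Before the diff, a minimal sketch of the parameter-sharing pattern this change targets: one Linear layer reused in two places, with its weight exposed through an `extra_parameters` list, mirroring the SpecialModel test added below. `TiedModel` and the layer sizes are illustrative assumptions, not part of this commit.

import paddle

class TiedModel(paddle.nn.Layer):  # hypothetical name, not from the diff
    def __init__(self):
        super().__init__()
        self.shared = paddle.nn.Linear(1024, 1024, bias_attr=False)
        self.head_a = paddle.nn.Linear(1024, 1024)
        self.head_b = paddle.nn.Linear(1024, 2)
        # Expose the shared weight to the sharding wrapper, as the
        # SpecialModel test below does.
        self.extra_parameters = [self.shared.weight]

    def forward(self, x):
        x = self.shared(x)   # first use of the shared layer
        x = self.head_a(x)
        x = self.shared(x)   # second use: the very same parameter object
        return self.head_b(x)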
......@@ -86,6 +86,7 @@ class GroupShardedStage3(nn.Layer):
offload=False,
sync_comm=False,
dp_group=None,
exclude_layer=None,
):
super().__init__()
......@@ -96,6 +97,14 @@ class GroupShardedStage3(nn.Layer):
self.__sync_buffers = sync_buffers
self._offload = offload
self._sync_comm = sync_comm
# stage3 allows user-specified layers to stay unsliced.
# _exclude_layer = [layer class name or id(layer)]
self._exclude_layer = [] if exclude_layer is None else exclude_layer
assert isinstance(
self._exclude_layer, (list, tuple)
), "the exclude_layers must be a list with layers' name or layers' id"
# segmentation size
assert segment_size >= 0, "segment_size must be greater than or equal to 0."
self._segment_size = segment_size
......@@ -350,6 +359,19 @@ class GroupShardedStage3(nn.Layer):
Parameter segmentation and memory integration.
"""
if id(layer) in self._trainable_params.keys():
return
# layers listed in self._exclude_layer are kept unsliced.
if (
id(layer) in self._exclude_layer
or layer.__class__.__name__ in self._exclude_layer
):
for p in current_layer_params:
if p.trainable:
self._unslice_params.add(_UnsliceParam(p))
return
def _add_manage_info(trainable_param):
return _PartitionParam(trainable_param)
......@@ -360,7 +382,6 @@ class GroupShardedStage3(nn.Layer):
elif p.trainable:
self._unslice_params.add(_UnsliceParam(p))
assert id(layer) not in self._trainable_params.keys()
self._trainable_params[id(layer)] = current_params
for param in self._trainable_params[id(layer)]:
......@@ -463,6 +484,11 @@ class GroupShardedStage3(nn.Layer):
"""
current_layer_params = _current_layer_params(layer)
if current_layer_params:
# layers listed in self._exclude_layer are not registered with forward hooks.
if not (
id(layer) in self._exclude_layer
or layer.__class__.__name__ in self._exclude_layer
):
self._register_forward_all_hooks(layer, self._task_flow)
for _, sub_layer in layer.named_children():
......
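The new exclusion branch above matches a layer either by `id(layer)` or by its class name. Below is a standalone sketch of that membership rule; the helper name `is_excluded` is illustrative, not from the diff.

import paddle

def is_excluded(layer, exclude_layer):
    # A layer is excluded when its id() or its class name appears in the list.
    exclude_layer = exclude_layer or []
    return (
        id(layer) in exclude_layer
        or layer.__class__.__name__ in exclude_layer
    )

norm = paddle.nn.GroupNorm(64, 1024)
linear = paddle.nn.Linear(1024, 1024)

assert is_excluded(norm, ["GroupNorm"])        # matched by class name
assert is_excluded(linear, [id(linear)])       # matched by identity
assert not is_excluded(linear, ["GroupNorm"])  # no match: sliced as usual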
......@@ -46,6 +46,7 @@ def group_sharded_parallel(
segment_size=2**20,
sync_comm=False,
dp_group=None,
exclude_layer=None,
):
"""
Use group_sharded_parallel to apply group sharded configuration to the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os', corresponding to three usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
......@@ -63,6 +64,7 @@ def group_sharded_parallel(
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
sync_comm (bool, optional): Whether to use synchronous communication; only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
dp_group(Group, optional): dp communication group; supports combining stage2 or stage3 with dp hybrid communication.
exclude_layer(list, optional): Layers excluded from slicing in sharding stage3. Each element must be a layer's class name or a layer's id, for example, exclude_layer=["GroupNorm", id(model.gpt.linear)]. Defaults to None.
Returns:
model: A wrapper for group sharded given model.
......@@ -159,6 +161,7 @@ def group_sharded_parallel(
sync_comm=sync_comm,
dp_group=dp_group,
device=device,
exclude_layer=exclude_layer,
)
else:
raise ValueError("Please enter the correct level.")
......
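For context, a hedged end-to-end sketch of the documented call with the new argument. It assumes the public import path `paddle.distributed.sharding.group_sharded_parallel` and a multi-GPU run launched via `python -m paddle.distributed.launch`; the model and hyperparameters are placeholders, not taken from this commit.

import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

first_linear = paddle.nn.Linear(1024, 1024)
model = paddle.nn.Sequential(
    first_linear,
    paddle.nn.GroupNorm(64, 1024),
    paddle.nn.Linear(1024, 2),
)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# Keep every GroupNorm and one specific Linear unsliced under stage3 ('p_g_os');
# list elements are layer class names or id(layer) values.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    level="p_g_os",
    scaler=scaler,
    exclude_layer=["GroupNorm", id(first_linear)],
)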
......@@ -21,7 +21,6 @@ import tempfile
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
GroupShardedOptimizerStage2,
)
......@@ -44,7 +43,7 @@ momentum_rate = 0.9
l2_decay = 1e-4
class MLP(fluid.Layer):
class MLP(paddle.nn.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super().__init__()
......@@ -59,6 +58,50 @@ class MLP(fluid.Layer):
return y
class Encoder(paddle.nn.Layer):
def __init__(self, encoder):
super(Encoder, self).__init__()
self.first_stage = paddle.nn.Linear(1024, 1024)
self.encoder = encoder
def forward(self, x):
x = self.encoder(x)
x = self.first_stage(x)
return x
class Decoder(paddle.nn.Layer):
def __init__(self, decoder):
super(Decoder, self).__init__()
self.decoder = decoder
self.final_stage = paddle.nn.Linear(1024, 1024)
self.group_norm = paddle.nn.GroupNorm(64, 1024)
def forward(self, x):
x = self.final_stage(x)
x = self.decoder(x)
x = self.group_norm(x)
return x
class SpecialModel(paddle.nn.Layer):
def __init__(self):
super(SpecialModel, self).__init__()
self.shared = paddle.nn.Linear(1024, 1024, bias_attr=False)
self.encoder = Encoder(self.shared)
self.decoder = Decoder(self.shared)
self.final_stage = paddle.nn.Linear(1024, 2, bias_attr=False)
self.extra_parameters = [self.shared.weight]
def forward(self, x):
x = self.shared(x)
x = self.encoder(x)
x = self.decoder(x)
x = self.final_stage(x)
return x
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
......@@ -91,9 +134,11 @@ def train_mlp(
accumulate_grad=False,
batch_size=100,
opt_group=False,
linear_size=1000,
sync_comm=False,
test_minimize=False,
save_model=False,
exclude_test=[],
):
group = paddle.distributed.new_group([0, 1])
if opt_group:
......@@ -123,6 +168,7 @@ def train_mlp(
group=group,
sync_comm=sync_comm,
segment_size=2**15,
exclude_layer=exclude_test,
)
# check optimizer.minimize() error
......@@ -136,7 +182,9 @@ def train_mlp(
return
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True
reader_decorator(linear_size=linear_size),
batch_size=batch_size,
drop_last=True,
)
train_loader = paddle.io.DataLoader.from_generator(
......@@ -154,7 +202,7 @@ def train_mlp(
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label
......@@ -289,6 +337,66 @@ def test_stage2_stage3():
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6
)
# test shared layer parameters and the exclude_layer feature.
sm1, sm2, sm3, sm4 = (
SpecialModel(),
SpecialModel(),
SpecialModel(),
SpecialModel(),
)
st_dict = sm1.state_dict()
sm2.set_state_dict(st_dict)
sm3.set_state_dict(st_dict)
sm4.set_state_dict(st_dict)
# fp16 for special model
stage2_params = train_mlp(
sm1,
sharding_stage=2,
use_pure_fp16=True,
opt_group=False,
linear_size=1024,
)
stage3_params = train_mlp(
sm2,
sharding_stage=3,
use_pure_fp16=True,
opt_group=False,
linear_size=1024,
exclude_test=["GroupNorm"],
)
for i in range(len(stage2_params)):
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-3,
)
# fp32 for special model
stage2_params = train_mlp(
sm3,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False,
linear_size=1024,
)
stage3_params = train_mlp(
sm4,
sharding_stage=3,
use_pure_fp16=False,
opt_group=False,
linear_size=1024,
exclude_test=[id(sm4.decoder.group_norm)],
)
for i in range(len(stage2_params)):
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-6,
atol=1e-4,
)
# save/load model
output_dir = tempfile.mkdtemp()
model_file = os.path.join(output_dir, "model.pdmodel")
......