Unverified · Commit 79b261ba authored by wuhuachaocoding, committed by GitHub

solve shared params bugs and add exclude_layer attr for stage3. (#48695)

Parent 3f2f036c
@@ -86,6 +86,7 @@ class GroupShardedStage3(nn.Layer):
         offload=False,
         sync_comm=False,
         dp_group=None,
+        exclude_layer=None,
     ):
         super().__init__()
@@ -96,6 +97,14 @@ class GroupShardedStage3(nn.Layer):
         self.__sync_buffers = sync_buffers
         self._offload = offload
         self._sync_comm = sync_comm
+
+        # stage3 supports keeping user-specified layers unsliced:
+        # _exclude_layer = [layer class name or id(layer)]
+        self._exclude_layer = [] if exclude_layer is None else exclude_layer
+        assert isinstance(
+            self._exclude_layer, (list, tuple)
+        ), "exclude_layer must be a list or tuple of layers' class names or layers' ids"
+
         # segmentation size
         assert segment_size >= 0, "segment_size must be GE than 0."
         self._segment_size = segment_size
@@ -350,6 +359,19 @@ class GroupShardedStage3(nn.Layer):
         Parameter segmentation and memory integration.
         """
+        if id(layer) in self._trainable_params.keys():
+            return
+
+        # layers listed in self._exclude_layer are kept unsliced
+        if (
+            id(layer) in self._exclude_layer
+            or layer.__class__.__name__ in self._exclude_layer
+        ):
+            for p in current_layer_params:
+                if p.trainable:
+                    self._unslice_params.add(_UnsliceParam(p))
+            return
+
         def _add_manage_info(trainable_param):
             return _PartitionParam(trainable_param)
@@ -360,7 +382,6 @@ class GroupShardedStage3(nn.Layer):
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))

-        assert id(layer) not in self._trainable_params.keys()
         self._trainable_params[id(layer)] = current_params

         for param in self._trainable_params[id(layer)]:
@@ -463,7 +484,12 @@ class GroupShardedStage3(nn.Layer):
         """
         current_layer_params = _current_layer_params(layer)
         if current_layer_params:
-            self._register_forward_all_hooks(layer, self._task_flow)
+            # layers listed in self._exclude_layer are skipped, so no forward hooks are registered for them
+            if not (
+                id(layer) in self._exclude_layer
+                or layer.__class__.__name__ in self._exclude_layer
+            ):
+                self._register_forward_all_hooks(layer, self._task_flow)

         for _, sub_layer in layer.named_children():
             self._register_forward_hooks(sub_layer)
......
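The two checks added above can be read in isolation. Below is a minimal sketch, not part of the commit, of how the matching behaves: a layer is excluded when either its id() or its class name appears in the exclusion list, and the early return that replaces the removed assert lets a layer shared by two parent modules be visited more than once. The names `managed`, `unsliced`, `exclude` and `add_manage_info` are illustrative stand-ins for the stage3 internals.

```python
import paddle

# Stand-ins for self._trainable_params and self._unslice_params (hypothetical).
managed = {}
unsliced = set()

norm = paddle.nn.GroupNorm(4, 16)
shared = paddle.nn.Linear(16, 16)
plain = paddle.nn.Linear(16, 16)

# Exclusion list mixing a class name and a concrete layer id, as the new attribute allows.
exclude = ["GroupNorm", id(plain)]


def add_manage_info(layer, params):
    # Early return replaces the removed assert: a layer reached through two
    # parent modules is simply skipped the second time instead of failing.
    if id(layer) in managed:
        return
    # Membership test matches by id(layer) or by class name.
    if id(layer) in exclude or layer.__class__.__name__ in exclude:
        unsliced.update(p.name for p in params if p.trainable)
        return
    managed[id(layer)] = [p for p in params if p.trainable]


for layer in (norm, shared, shared, plain):  # `shared` appears under two parents
    add_manage_info(layer, list(layer.parameters()))

print(len(managed), len(unsliced))  # 1 managed layer, 4 parameters kept unsliced
```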
@@ -46,6 +46,7 @@ def group_sharded_parallel(
     segment_size=2**20,
     sync_comm=False,
     dp_group=None,
+    exclude_layer=None,
 ):
     """
     Use group_sharded_parallel to apply a group sharded configuration to the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os', corresponding to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
@@ -63,6 +64,7 @@ def group_sharded_parallel(
         segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
         sync_comm (bool, optional): Whether to use synchronous communication, only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
         dp_group (Group, optional): dp communication group, supports combining stage2 or stage3 with dp hybrid communication.
+        exclude_layer (list, optional): Layers excluded from slicing in sharding stage3, identified by class name or by id(layer), for example exclude_layer=["GroupNorm", id(model.gpt.linear)]. Each entry must be a layer's class name or a layer's id. Defaults to None.

     Returns:
         model: A wrapper for the given model with group sharding applied.
@@ -159,6 +161,7 @@ def group_sharded_parallel(
             sync_comm=sync_comm,
             dp_group=dp_group,
             device=device,
+            exclude_layer=exclude_layer,
         )
     else:
         raise ValueError("Please enter the correct level.")
......
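For context, here is a sketch of how the new argument is passed at the user level through group_sharded_parallel. This is not part of the commit: the model, optimizer and GradScaler below are placeholders, and the snippet assumes a Paddle build containing this change plus a multi-GPU collective launch (e.g. python -m paddle.distributed.launch).

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

# Placeholder model: keep the GroupNorm (by class name) and the first Linear
# (by id) unsliced under stage3.
linear = paddle.nn.Linear(1024, 1024)
norm = paddle.nn.GroupNorm(64, 1024)
model = paddle.nn.Sequential(linear, norm)

optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# 'p_g_os' selects stage3; exclude_layer is forwarded to GroupShardedStage3.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    "p_g_os",
    scaler=scaler,
    exclude_layer=["GroupNorm", id(linear)],
)
```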
@@ -21,7 +21,6 @@ import tempfile
 import numpy as np

 import paddle
-import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
@@ -44,7 +43,7 @@ momentum_rate = 0.9
 l2_decay = 1e-4


-class MLP(fluid.Layer):
+class MLP(paddle.nn.Layer):
     def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
         super().__init__()
@@ -59,6 +58,50 @@ class MLP(fluid.Layer):
         return y


+class Encoder(paddle.nn.Layer):
+    def __init__(self, encoder):
+        super(Encoder, self).__init__()
+        self.first_stage = paddle.nn.Linear(1024, 1024)
+        self.encoder = encoder
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.first_stage(x)
+        return x
+
+
+class Decoder(paddle.nn.Layer):
+    def __init__(self, decoder):
+        super(Decoder, self).__init__()
+        self.decoder = decoder
+        self.final_stage = paddle.nn.Linear(1024, 1024)
+        self.group_norm = paddle.nn.GroupNorm(64, 1024)
+
+    def forward(self, x):
+        x = self.final_stage(x)
+        x = self.decoder(x)
+        x = self.group_norm(x)
+        return x
+
+
+class SpecialModel(paddle.nn.Layer):
+    def __init__(self):
+        super(SpecialModel, self).__init__()
+        self.shared = paddle.nn.Linear(1024, 1024, bias_attr=False)
+        self.encoder = Encoder(self.shared)
+        self.decoder = Decoder(self.shared)
+        self.final_stage = paddle.nn.Linear(1024, 2, bias_attr=False)
+        self.extra_parameters = [self.shared.weight]
+
+    def forward(self, x):
+        x = self.shared(x)
+        x = self.encoder(x)
+        x = self.decoder(x)
+        x = self.final_stage(x)
+        return x
+
+
 def reader_decorator(linear_size=1000):
     def __reader__():
         for _ in range(100):
@@ -91,9 +134,11 @@ def train_mlp(
     accumulate_grad=False,
     batch_size=100,
     opt_group=False,
+    linear_size=1000,
     sync_comm=False,
     test_minimize=False,
     save_model=False,
+    exclude_test=[],
 ):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
@@ -123,6 +168,7 @@ def train_mlp(
             group=group,
             sync_comm=sync_comm,
             segment_size=2**15,
+            exclude_layer=exclude_test,
         )

     # check optimizer.minimize() error
@@ -136,7 +182,9 @@ def train_mlp(
         return

     train_reader = paddle.batch(
-        reader_decorator(), batch_size=batch_size, drop_last=True
+        reader_decorator(linear_size=linear_size),
+        batch_size=batch_size,
+        drop_last=True,
     )

     train_loader = paddle.io.DataLoader.from_generator(
@@ -154,7 +202,7 @@ def train_mlp(
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(
                     input=out, label=label
@@ -289,6 +337,66 @@ def test_stage2_stage3():
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6
         )

+    # test for shared layer parameters and the exclude_layer option
+    sm1, sm2, sm3, sm4 = (
+        SpecialModel(),
+        SpecialModel(),
+        SpecialModel(),
+        SpecialModel(),
+    )
+    st_dict = sm1.state_dict()
+    sm2.set_state_dict(st_dict)
+    sm3.set_state_dict(st_dict)
+    sm4.set_state_dict(st_dict)
+
+    # fp16 for the special model
+    stage2_params = train_mlp(
+        sm1,
+        sharding_stage=2,
+        use_pure_fp16=True,
+        opt_group=False,
+        linear_size=1024,
+    )
+    stage3_params = train_mlp(
+        sm2,
+        sharding_stage=3,
+        use_pure_fp16=True,
+        opt_group=False,
+        linear_size=1024,
+        exclude_test=["GroupNorm"],
+    )
+    for i in range(len(stage2_params)):
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-4,
+            atol=1e-3,
+        )
+
+    # fp32 for the special model
+    stage2_params = train_mlp(
+        sm3,
+        sharding_stage=2,
+        use_pure_fp16=False,
+        opt_group=False,
+        linear_size=1024,
+    )
+    stage3_params = train_mlp(
+        sm4,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        linear_size=1024,
+        exclude_test=[id(sm4.decoder.group_norm)],
+    )
+    for i in range(len(stage2_params)):
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-4,
+        )
+
     # save/load model
     output_dir = tempfile.mkdtemp()
     model_file = os.path.join(output_dir, "model.pdmodel")
......