diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index 72f4cc5d2a664f9225ee86a3ca36499e2ce774d5..f792e0a53857c5fedb905a907ed5b7201da99bc5 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -86,6 +86,7 @@ class GroupShardedStage3(nn.Layer):
         offload=False,
         sync_comm=False,
         dp_group=None,
+        exclude_layer=None,
     ):
         super().__init__()

@@ -96,6 +97,14 @@ class GroupShardedStage3(nn.Layer):
         self.__sync_buffers = sync_buffers
         self._offload = offload
         self._sync_comm = sync_comm
+
+        # stage3 allows user-specified layers to stay unsliced, e.g.
+        # _exclude_layer=[layer class name or id(layer)]
+        self._exclude_layer = [] if exclude_layer is None else exclude_layer
+        assert isinstance(
+            self._exclude_layer, (list, tuple)
+        ), "exclude_layer must be a list or tuple of layer class names or layer ids"
+
         # segmentation size
         assert segment_size >= 0, "segment_size must be GE than 0."
         self._segment_size = segment_size
@@ -350,6 +359,19 @@ class GroupShardedStage3(nn.Layer):
         """
         Parameter segmentation and memory integration.
         """
+        if id(layer) in self._trainable_params.keys():
+            return
+
+        # layers listed in self._exclude_layer are kept unsliced.
+        if (
+            id(layer) in self._exclude_layer
+            or layer.__class__.__name__ in self._exclude_layer
+        ):
+            for p in current_layer_params:
+                if p.trainable:
+                    self._unslice_params.add(_UnsliceParam(p))
+            return
+
         def _add_manage_info(trainable_param):
             return _PartitionParam(trainable_param)

@@ -360,7 +382,6 @@
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))

-        assert id(layer) not in self._trainable_params.keys()
         self._trainable_params[id(layer)] = current_params

         for param in self._trainable_params[id(layer)]:
@@ -463,7 +484,12 @@
         """
         current_layer_params = _current_layer_params(layer)
         if current_layer_params:
-            self._register_forward_all_hooks(layer, self._task_flow)
+            # layers listed in self._exclude_layer do not get forward hooks.
+            if not (
+                id(layer) in self._exclude_layer
+                or layer.__class__.__name__ in self._exclude_layer
+            ):
+                self._register_forward_all_hooks(layer, self._task_flow)
         for _, sub_layer in layer.named_children():
             self._register_forward_hooks(sub_layer)

diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index e8e6c2ebc9ca399461d42f366518f76a06d9051e..0cea2d851ad50a2b6de6dd6edf7efc8f314be397 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -46,6 +46,7 @@ def group_sharded_parallel(
     segment_size=2**20,
     sync_comm=False,
     dp_group=None,
+    exclude_layer=None,
 ):
     """
     Use group_sharded_parallel can perform group shared configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os' corresponds to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
@@ -63,6 +64,7 @@ def group_sharded_parallel(
         segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
         sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used.
         dp_group(Group, optional): dp communication group, support to combine stage2 or stage3 with dp hybrid communication.
+        exclude_layer(list, optional): Exclude certain layers from parameter slicing in sharding stage3 ('p_g_os'). Each element must be a layer class name (str) or a layer's id, for example, exclude_layer=["GroupNorm", id(model.gpt.linear)]. Defaults to None.

     Returns:
         model: A wrapper for group sharded given model.
@@ -159,6 +161,7 @@ def group_sharded_parallel(
             sync_comm=sync_comm,
             dp_group=dp_group,
             device=device,
+            exclude_layer=exclude_layer,
         )
     else:
         raise ValueError("Please enter the correct level.")

diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
index 245d71d3379f59490bd240ba3d712c386baadfaa..84a77855d9464d99a6da1b8197e68ef37f0cac5c 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py
@@ -21,7 +21,6 @@ import tempfile
 import numpy as np

 import paddle
-import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
@@ -44,7 +43,7 @@ momentum_rate = 0.9
 l2_decay = 1e-4


-class MLP(fluid.Layer):
+class MLP(paddle.nn.Layer):
     def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
         super().__init__()

@@ -59,6 +58,50 @@ class MLP(fluid.Layer):
         return y


+class Encoder(paddle.nn.Layer):
+    def __init__(self, encoder):
+        super().__init__()
+        self.first_stage = paddle.nn.Linear(1024, 1024)
+        self.encoder = encoder
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.first_stage(x)
+        return x
+
+
+class Decoder(paddle.nn.Layer):
+    def __init__(self, decoder):
+        super().__init__()
+        self.decoder = decoder
+        self.final_stage = paddle.nn.Linear(1024, 1024)
+        self.group_norm = paddle.nn.GroupNorm(64, 1024)
+
+    def forward(self, x):
+        x = self.final_stage(x)
+        x = self.decoder(x)
+        x = self.group_norm(x)
+        return x
+
+
+class SpecialModel(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.shared = paddle.nn.Linear(1024, 1024, bias_attr=False)
+        self.encoder = Encoder(self.shared)
+        self.decoder = Decoder(self.shared)
+        self.final_stage = paddle.nn.Linear(1024, 2, bias_attr=False)
+
+        self.extra_parameters = [self.shared.weight]
+
+    def forward(self, x):
+        x = self.shared(x)
+        x = self.encoder(x)
+        x = self.decoder(x)
+        x = self.final_stage(x)
+        return x
+
+
 def reader_decorator(linear_size=1000):
     def __reader__():
         for _ in range(100):
@@ -91,9 +134,11 @@ def train_mlp(
     accumulate_grad=False,
     batch_size=100,
     opt_group=False,
+    linear_size=1000,
     sync_comm=False,
     test_minimize=False,
     save_model=False,
+    exclude_test=[],
 ):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
@@ -123,6 +168,7 @@
             group=group,
             sync_comm=sync_comm,
             segment_size=2**15,
+            exclude_layer=exclude_test,
         )

     # check optimizer.minimize() error
@@ -136,7 +182,9 @@
         return

     train_reader = paddle.batch(
-        reader_decorator(), batch_size=batch_size, drop_last=True
+        reader_decorator(linear_size=linear_size),
+        batch_size=batch_size,
+        drop_last=True,
     )

     train_loader = paddle.io.DataLoader.from_generator(
@@ -154,7 +202,7 @@
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(True, level='O2'):
+            with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(
                     input=out, label=label
@@ -289,6 +337,66 @@
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6
         )

+    # test shared layer parameters together with the exclude_layer function.
+    sm1, sm2, sm3, sm4 = (
+        SpecialModel(),
+        SpecialModel(),
+        SpecialModel(),
+        SpecialModel(),
+    )
+    st_dict = sm1.state_dict()
+    sm2.set_state_dict(st_dict)
+    sm3.set_state_dict(st_dict)
+    sm4.set_state_dict(st_dict)
+
+    # fp16 for special model
+    stage2_params = train_mlp(
+        sm1,
+        sharding_stage=2,
+        use_pure_fp16=True,
+        opt_group=False,
+        linear_size=1024,
+    )
+    stage3_params = train_mlp(
+        sm2,
+        sharding_stage=3,
+        use_pure_fp16=True,
+        opt_group=False,
+        linear_size=1024,
+        exclude_test=["GroupNorm"],
+    )
+    for i in range(len(stage2_params)):
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-4,
+            atol=1e-3,
+        )
+
+    # fp32 for special model
+    stage2_params = train_mlp(
+        sm3,
+        sharding_stage=2,
+        use_pure_fp16=False,
+        opt_group=False,
+        linear_size=1024,
+    )
+    stage3_params = train_mlp(
+        sm4,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        linear_size=1024,
+        exclude_test=[id(sm4.decoder.group_norm)],
+    )
+    for i in range(len(stage2_params)):
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-4,
+        )
+
     # save/load model
     output_dir = tempfile.mkdtemp()
     model_file = os.path.join(output_dir, "model.pdmodel")
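For context, a minimal usage sketch of the new exclude_layer argument (not part of the patch): it shows how exclude_layer is passed through group_sharded_parallel at level "p_g_os". The ToyModel class and its layer names are illustrative placeholders; the script assumes a multi-GPU run launched with paddle.distributed.launch.

# Hypothetical sketch: keep GroupNorm layers (by class name) and one specific
# layer (by id) unsliced when wrapping a model with sharding stage3.
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel


class ToyModel(paddle.nn.Layer):
    # Placeholder model, only to illustrate which layers get excluded.
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(1024, 1024)
        self.norm = paddle.nn.GroupNorm(64, 1024)
        self.fc2 = paddle.nn.Linear(1024, 2)

    def forward(self, x):
        return self.fc2(self.norm(self.fc1(x)))


fleet.init(is_collective=True)

model = ToyModel()
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# exclude_layer only takes effect for level="p_g_os" (stage3); the ids are
# taken from the unwrapped model before group_sharded_parallel returns the
# wrapped one.
model, optimizer, scaler = group_sharded_parallel(
    model,
    optimizer,
    "p_g_os",
    scaler=scaler,
    exclude_layer=["GroupNorm", id(model.fc2)],
)

Design note: excluding small normalization layers mirrors what the added test does with SpecialModel, where GroupNorm is excluded once by class name and once by layer id.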