diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index ac41b4af4c9b0abf70b84a716ce561bbfed8cd8a..ad4d53cb08254e0a41934405b3a963756a7c1053 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -346,7 +346,7 @@ class GroupShardedStage3(nn.Layer):
 
         current_params = list()
         for p in current_layer_params:
-            if p.trainable and p._numel() > self._segment_size:
+            if p._numel() > self._segment_size:
                 current_params.append(_add_manage_info(p))
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))
@@ -430,7 +430,11 @@ class GroupShardedStage3(nn.Layer):
         param.status = "part"
 
         # Updata optimizer master weights
-        if param.dtype == Type.fp16.value and not self._offload:
+        if (
+            param.trainable
+            and param.dtype == Type.fp16.value
+            and not self._offload
+        ):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
@@ -599,6 +603,9 @@ class GroupShardedStage3(nn.Layer):
     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
        def allreduce_(*_):
+            assert (
+                param.trainable
+            ), "the param must be trainable for its grad to be allreduced"
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -962,6 +969,8 @@ def _allgather_buffer(
 @paddle.autograd.no_grad()
 def _create_params_grad(trainable_params, param2buffer_size, task_flow):
     for param in trainable_params:
+        if not param.trainable:
+            continue
         if param.name in task_flow.full_grad.keys():
             continue
         assert isinstance(param2buffer_size[param.name], int)
diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py
index a69718261d9092b5e91d2e7965ca90cb6bb58f00..012008913eee099fe4b5ee00a9c17a43626c1f23 100644
--- a/python/paddle/distributed/sharding/group_sharded.py
+++ b/python/paddle/distributed/sharding/group_sharded.py
@@ -140,7 +140,9 @@ def group_sharded_parallel(
 
     params_fp16 = list(filter(check_dtype, model.parameters()))
     if scaler is None and len(params_fp16) > 0:
-        raise ValueError("Please enter the correct scaler.")
+        logger_.warning(
+            "the input scaler is None, please make sure the logic of your external scaler is the same as GroupShardedScaler."
+        )
     # convert model/optimizer/scaler
     if level in ['os', 'os_g']:
         logger_.info("*" * 30)
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_eager.py
new file mode 100644
index 0000000000000000000000000000000000000000..efd7a7b1ce70c147087652c360fbf080251e50dd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_eager.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+
+import paddle
+from paddle import nn
+from paddle.distributed.sharding import group_sharded_parallel
+from paddle.fluid.framework import _test_eager_guard
+
+paddle.seed(2022)
+np.random.seed(2022)
+
+
+class Model(nn.Layer):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.first_stage = nn.Linear(4096, 4096, bias_attr=False)
+        self.center_stage = nn.Linear(4096, 4096)
+        self.center_stage.weight.stop_gradient = True
+        self.center_stage.bias.stop_gradient = True
+        self.final_stage = nn.Linear(4096, 2, bias_attr=False)
+
+    def forward(self, x):
+        x = self.first_stage(x)
+        x = self.center_stage(x)
+        x = self.final_stage(x)
+        return x
+
+
+def optimizer_setting(model, use_multi_precision):
+    optimizer = paddle.optimizer.AdamW(
+        learning_rate=0.001,
+        parameters=model.parameters(),
+        multi_precision=use_multi_precision,
+    )
+    return optimizer
+
+
+def train_mlp(
+    model,
+    shard_level="p_g_os",
+    use_multi_precision=False,
+    output_dir="",
+    amp_level='O1',
+    sync_buffers=False,
+    use_sharding=True,
+    data=None,
+):
+    optimizer = optimizer_setting(
+        model=model, use_multi_precision=use_multi_precision
+    )
+    if use_multi_precision:
+        model = paddle.amp.decorate(models=model, level=amp_level)
+
+    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+
+    if use_sharding:
+        model, optimizer, scaler = group_sharded_parallel(
+            model=model,
+            optimizer=optimizer,
+            level=shard_level,
+            scaler=scaler,
+            sync_buffers=sync_buffers,
+        )
+
+    res_loss = []
+    for i in range(20):
+        model.train()
+        img = data[i]
+        with paddle.amp.auto_cast(use_multi_precision, level=amp_level):
+            out = model(img)
+            avg_loss = out.mean()
+
+        res_loss.append(avg_loss.item())
+
+        if not use_multi_precision:
+            avg_loss.backward()
+            optimizer.step()
+        else:
+            scaler.scale(avg_loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+        optimizer.clear_grad()
+
+    return res_loss
+
+
+def test_sharding_api():
+    paddle.distributed.init_parallel_env()
+
+    # just test the missing-scaler warning
+    model = Model()
+    model = paddle.amp.decorate(models=model, level="O2")
+    optimizer = optimizer_setting(model=model, use_multi_precision=True)
+    model, optimizer, scaler = group_sharded_parallel(
+        model=model,
+        optimizer=optimizer,
+        level="p_g_os",
+    )
+
+    data = [paddle.randn([8, 4096]) for i in range(20)]
+
+    model = Model()
+    sd3_model = Model()
+    sd3_model.set_state_dict(model.state_dict())
+
+    # dp fp32
+    dp_fp32_loss = train_mlp(
+        model, use_multi_precision=False, use_sharding=False, data=data
+    )
+
+    # stage3 fp32
+    sd3_fp32_loss = train_mlp(
+        sd3_model,
+        shard_level="p_g_os",
+        use_multi_precision=False,
+        use_sharding=True,
+        data=data,
+    )
+
+    print("dp_fp32_loss: ", dp_fp32_loss)
+    print("sd3_fp32_loss: ", sd3_fp32_loss)
+
+    for i in range(len(dp_fp32_loss)):
+        np.testing.assert_allclose(
+            np.array(dp_fp32_loss[i]),
+            np.array(sd3_fp32_loss[i]),
+            rtol=1e-8,
+            atol=1e-8,
+        )
+
+    model = Model()
+    sd3_model = Model()
+    sd3_model.set_state_dict(model.state_dict())
+
+    # dp fp16
+    dp_fp16_loss = train_mlp(
+        model, use_multi_precision=True, use_sharding=False, data=data
+    )
+
+    # stage3 fp16
+    sd3_fp16_loss = train_mlp(
+        sd3_model,
+        shard_level="p_g_os",
+        use_multi_precision=True,
+        use_sharding=True,
+        data=data,
+    )
+
+    print("dp_fp16_loss: ", dp_fp16_loss)
+    print("sd3_fp16_loss: ", sd3_fp16_loss)
+
+    for i in range(len(dp_fp16_loss)):
+        np.testing.assert_allclose(
+            np.array(dp_fp16_loss[i]),
+            np.array(sd3_fp16_loss[i]),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+
+
+if __name__ == '__main__':
+    with _test_eager_guard():
+        test_sharding_api()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_group_sharded_api_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_group_sharded_api_for_eager.py
index a8b9a3229bdd02908f77ab8a41df0186b7e6fcf6..ecf864cf806f67da118c27f4b9bcd77d3cb19876 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_group_sharded_api_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_group_sharded_api_for_eager.py
@@ -27,6 +27,10 @@ class TestDygraphGroupSharded(TestMultipleGpus):
     def test_dygraph_group_sharded(self):
         self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
 
+    # check stage3 with untrainable parameters.
+    def test_dygraph_group_sharded_stage3(self):
+        self.run_mnist_2gpu('dygraph_group_sharded_stage3_eager.py')
+
 
 if __name__ == "__main__":
     unittest.main()