diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py index d99683d481450309d95d13dfb26b0bc3471ea5e3..72f4cc5d2a664f9225ee86a3ca36499e2ce774d5 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py @@ -428,10 +428,11 @@ class GroupShardedStage3(nn.Layer): place=core.CPUPlace(), name="slice@" + param.name, ) - with device_guard(): - param.master_weight = paddle.cast( - param.fw_storage, Type.fp32.value - ) + if param.trainable: + with device_guard(): + param.master_weight = paddle.cast( + param.fw_storage, Type.fp32.value + ) else: param.fw_storage = core.eager.Tensor( value=buffer._slice(start, end), name="slice@" + param.name diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py index 07ebba2bedfa2790c9f768fab1c34c283d6779f2..15b02b2cea625190be9e30f4943e4da7c52b1a4f 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py @@ -40,6 +40,9 @@ class MLP(fluid.Layer): self._linear1 = Linear(linear_size, linear_size) self._linear2 = Linear(linear_size, linear_size) + # test for trainable & untrainable offload + self._linear2.weight.stop_gradient = False + self._linear2.bias.stop_gradient = False self._linear3 = Linear(linear_size, 10) def forward(self, inputs): @@ -119,7 +122,7 @@ def train_mlp( img, label = data label.stop_gradient = True img.stop_gradient = True - with paddle.amp.auto_cast(True, level='O2'): + with paddle.amp.auto_cast(use_pure_fp16, level='O2'): out = model(img) loss = paddle.nn.functional.cross_entropy( input=out, label=label