From 451756fb5ac015d26627679475442a4cf5735041 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:39:17 +0800 Subject: [PATCH] support cpu offload for stage3 (#49196) --- .../fleet/meta_parallel/sharding/group_sharded_stage3.py | 9 +++++---- .../fleet/dygraph_group_sharded_stage3_offload.py | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py index d99683d481..72f4cc5d2a 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py @@ -428,10 +428,11 @@ class GroupShardedStage3(nn.Layer): place=core.CPUPlace(), name="slice@" + param.name, ) - with device_guard(): - param.master_weight = paddle.cast( - param.fw_storage, Type.fp32.value - ) + if param.trainable: + with device_guard(): + param.master_weight = paddle.cast( + param.fw_storage, Type.fp32.value + ) else: param.fw_storage = core.eager.Tensor( value=buffer._slice(start, end), name="slice@" + param.name diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py index 07ebba2bed..15b02b2cea 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py @@ -40,6 +40,9 @@ class MLP(fluid.Layer): self._linear1 = Linear(linear_size, linear_size) self._linear2 = Linear(linear_size, linear_size) + # test for trainable & untrainable offload + self._linear2.weight.stop_gradient = False + self._linear2.bias.stop_gradient = False self._linear3 = Linear(linear_size, 10) def forward(self, inputs): @@ -119,7 +122,7 @@ def train_mlp( img, label = data label.stop_gradient = True img.stop_gradient = True - with paddle.amp.auto_cast(True, level='O2'): + with paddle.amp.auto_cast(use_pure_fp16, level='O2'): out = model(img) loss = paddle.nn.functional.cross_entropy( input=out, label=label -- GitLab