未验证 提交 451756fb 编写于 作者: W wuhuachaocoding 提交者: GitHub

support cpu offload for stage3 (#49196)

上级 a36c5490
...@@ -428,6 +428,7 @@ class GroupShardedStage3(nn.Layer): ...@@ -428,6 +428,7 @@ class GroupShardedStage3(nn.Layer):
place=core.CPUPlace(), place=core.CPUPlace(),
name="slice@" + param.name, name="slice@" + param.name,
) )
if param.trainable:
with device_guard(): with device_guard():
param.master_weight = paddle.cast( param.master_weight = paddle.cast(
param.fw_storage, Type.fp32.value param.fw_storage, Type.fp32.value
......
...@@ -40,6 +40,9 @@ class MLP(fluid.Layer): ...@@ -40,6 +40,9 @@ class MLP(fluid.Layer):
self._linear1 = Linear(linear_size, linear_size) self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size) self._linear2 = Linear(linear_size, linear_size)
# test for trainable & untrainable offload
self._linear2.weight.stop_gradient = False
self._linear2.bias.stop_gradient = False
self._linear3 = Linear(linear_size, 10) self._linear3 = Linear(linear_size, 10)
def forward(self, inputs): def forward(self, inputs):
...@@ -119,7 +122,7 @@ def train_mlp( ...@@ -119,7 +122,7 @@ def train_mlp(
img, label = data img, label = data
label.stop_gradient = True label.stop_gradient = True
img.stop_gradient = True img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'): with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
out = model(img) out = model(img)
loss = paddle.nn.functional.cross_entropy( loss = paddle.nn.functional.cross_entropy(
input=out, label=label input=out, label=label
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册