Unverified commit e4ee872c, authored by wuhuachaocoding and committed via GitHub

update for untrainable params for stage3. (#48577)

Parent commit: 06b32b38
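The commit teaches GroupShardedStage3 ("p_g_os" sharding) to cope with parameters that are frozen via stop_gradient: large untrainable parameters are now segmented like trainable ones, while the gradient-allreduce hook, gradient buffers, and fp16 master weights skip them. Below is a minimal sketch of the scenario, modeled on the test added in this commit; the layer shapes, the Sequential layout, and the multi-GPU launch are illustrative assumptions, not part of the patch:

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = nn.Sequential(
    nn.Linear(1024, 1024, bias_attr=False),
    nn.Linear(1024, 1024),  # this middle block is frozen below
    nn.Linear(1024, 2, bias_attr=False),
)
# Mark the middle layer's parameters as untrainable; this is the case
# stage3 previously did not handle cleanly.
for p in model[1].parameters():
    p.stop_gradient = True

optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters()
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# Stage 3 now segments the frozen parameters as well, while skipping them
# for gradient allreduce and for optimizer master weights.
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="p_g_os", scaler=scaler
)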
@@ -346,7 +346,7 @@ class GroupShardedStage3(nn.Layer):
         current_params = list()
         for p in current_layer_params:
-            if p.trainable and p._numel() > self._segment_size:
+            if p._numel() > self._segment_size:
                 current_params.append(_add_manage_info(p))
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))
@@ -430,7 +430,11 @@ class GroupShardedStage3(nn.Layer):
         param.status = "part"
 
         # Updata optimizer master weights
-        if param.dtype == Type.fp16.value and not self._offload:
+        if (
+            param.trainable
+            and param.dtype == Type.fp16.value
+            and not self._offload
+        ):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
@@ -599,6 +603,9 @@ class GroupShardedStage3(nn.Layer):
     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
         def allreduce_(*_):
+            assert (
+                param.trainable
+            ), "the param must be trainable for grad allreduced"
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -962,6 +969,8 @@ def _allgather_buffer(
 @paddle.autograd.no_grad()
 def _create_params_grad(trainable_params, param2buffer_size, task_flow):
     for param in trainable_params:
+        if not param.trainable:
+            continue
         if param.name in task_flow.full_grad.keys():
             continue
         assert isinstance(param2buffer_size[param.name], int)
@@ -140,7 +140,9 @@ def group_sharded_parallel(
         params_fp16 = list(filter(check_dtype, model.parameters()))
         if scaler is None and len(params_fp16) > 0:
-            raise ValueError("Please enter the correct scaler.")
+            logger_.warning(
+                "the input of scaler is None, please ensure the logic of your scaler outside is same as GroupShardedScaler."
+            )
 
     # convert model/optimizer/scaler
     if level in ['os', 'os_g']:
         logger_.info("*" * 30)
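With this change, passing scaler=None while the model holds fp16 parameters no longer raises a ValueError; group_sharded_parallel only logs a warning and trusts the caller to reproduce GroupShardedScaler's scale/unscale logic outside. A minimal sketch of the warning path follows, assuming an O2-decorated single Linear layer purely for illustration:

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

# An fp16 (O2-decorated) model wrapped without a scaler: previously a
# ValueError, now only a warning from group_sharded_parallel.
model = paddle.amp.decorate(models=nn.Linear(64, 64), level='O2')
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level='p_g_os', scaler=None
)
# scaler stays None here; loss scaling and gradient unscaling become the
# caller's responsibility and should mirror GroupShardedScaler.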
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel
from paddle.fluid.framework import _test_eager_guard

paddle.seed(2022)
np.random.seed(2022)


class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.first_stage = nn.Linear(4096, 4096, bias_attr=False)
        self.center_stage = nn.Linear(4096, 4096)
        self.center_stage.weight.stop_gradient = True
        self.center_stage.bias.stop_gradient = True
        self.final_stage = nn.Linear(4096, 2, bias_attr=False)

    def forward(self, x):
        x = self.first_stage(x)
        x = self.center_stage(x)
        x = self.final_stage(x)
        return x


def optimizer_setting(model, use_multi_precision):
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001,
        parameters=model.parameters(),
        multi_precision=use_multi_precision,
    )
    return optimizer


def train_mlp(
    model,
    shard_level="p_g_os",
    use_multi_precision=False,
    output_dir="",
    amp_level='O1',
    sync_buffers=False,
    use_sharding=True,
    data=None,
):
    optimizer = optimizer_setting(
        model=model, use_multi_precision=use_multi_precision
    )
    if use_multi_precision:
        model = paddle.amp.decorate(models=model, level=amp_level)
    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

    if use_sharding:
        model, optimizer, scaler = group_sharded_parallel(
            model=model,
            optimizer=optimizer,
            level=shard_level,
            scaler=scaler,
            sync_buffers=sync_buffers,
        )

    res_loss = []
    for i in range(20):
        model.train()
        img = data[i]

        with paddle.amp.auto_cast(use_multi_precision, level=amp_level):
            out = model(img)
            avg_loss = out.mean()
        res_loss.append(avg_loss.item())

        if not use_multi_precision:
            avg_loss.backward()
            optimizer.step()
        else:
            scaler.scale(avg_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        optimizer.clear_grad()

    return res_loss


def test_sharding_api():
    paddle.distributed.init_parallel_env()

    # just test warning
    model = Model()
    model = paddle.amp.decorate(models=model, level="O2")
    optimizer = optimizer_setting(model=model, use_multi_precision=True)
    model, optimizer, scaler = group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level="p_g_os",
    )

    data = [paddle.randn([8, 4096]) for i in range(20)]

    model = Model()
    sd3_model = Model()
    sd3_model.set_state_dict(model.state_dict())

    # dp fp32
    dp_fp32_loss = train_mlp(
        model, use_multi_precision=False, use_sharding=False, data=data
    )
    # stage3 fp32
    sd3_fp32_loss = train_mlp(
        sd3_model,
        shard_level="p_g_os",
        use_multi_precision=False,
        use_sharding=True,
        data=data,
    )

    print("dp_fp32_loss: ", dp_fp32_loss)
    print("sd3_fp32_loss: ", sd3_fp32_loss)

    for i in range(len(dp_fp32_loss)):
        np.testing.assert_allclose(
            np.array(dp_fp32_loss[i]),
            np.array(sd3_fp32_loss[i]),
            rtol=1e-8,
            atol=1e-8,
        )

    model = Model()
    sd3_model = Model()
    sd3_model.set_state_dict(model.state_dict())

    # dp fp16
    dp_fp16_loss = train_mlp(
        model, use_multi_precision=True, use_sharding=False, data=data
    )
    # stage3 fp16
    sd3_fp16_loss = train_mlp(
        sd3_model,
        shard_level="p_g_os",
        use_multi_precision=True,
        use_sharding=True,
        data=data,
    )

    print("dp_fp16_loss: ", dp_fp16_loss)
    print("sd3_fp16_loss: ", sd3_fp16_loss)

    for i in range(len(dp_fp16_loss)):
        np.testing.assert_allclose(
            np.array(dp_fp16_loss[i]),
            np.array(sd3_fp16_loss[i]),
            rtol=1e-5,
            atol=1e-5,
        )


if __name__ == '__main__':
    with _test_eager_guard():
        test_sharding_api()
@@ -27,6 +27,10 @@ class TestDygraphGroupSharded(TestMultipleGpus):
     def test_dygraph_group_sharded(self):
         self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
 
+    # check stage3 for some functions.
+    def test_dygraph_group_sharded_stage3(self):
+        self.run_mnist_2gpu('dygraph_group_sharded_stage3_eager.py')
+
 
 if __name__ == "__main__":
     unittest.main()