Unverified commit f4623876 authored by Baibaifan, committed by GitHub

fix sharding stage2 unittest (#39112)

Parent 3c1dc6f6
@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
device="gpu",
**kw):
super().__init__(optim._learning_rate, params, kw)
# super().__init__(optim._learning_rate, params, kw)
# Segmentation information
self._dtype_rank_params = OrderedDict(
@@ -83,8 +83,6 @@ class ShardingOptimizerStage2(Optimizer):
# Default information
self._optim_defaults = kw
self._optim = optim
self._ori_parameter_list = self._optim._parameter_list
self._ori_param_groups = self._optim._param_groups
assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"
@@ -336,24 +334,11 @@ class ShardingOptimizerStage2(Optimizer):
if self.offload:
params_list = [self.offload_params.buffer]
else:
# Synchronize optimizer parameters for the current rank
params_list = []
for dtype in self.dtype_rank_params.keys():
params_list.extend(self.dtype_rank_params[dtype][self.rank])
params_name_list = list(map(lambda p: p.name, params_list))
if not isinstance(self._optim._param_groups[0], dict):
self._optim._parameter_list = params_list
self._optim._param_groups = params_list
else:
for param_group in self._optim._param_groups:
p_group = []
for param in param_group['params']:
if param.name in params_name_list:
p_group.append(params_list[params_name_list.index(
param.name)])
param_group['params'] = p_group
#TODO(Baibaifan): Offload will support param_groups later
if not isinstance(self._optim._param_groups[0], dict):
self._optim._parameter_list = params_list
self._optim._param_groups = params_list
# Run the optimizer of the current rank step
if self.offload:
@@ -371,10 +356,6 @@ class ShardingOptimizerStage2(Optimizer):
# Synchronize all the updated shards in between the ranks
self._broadcast_params()
# Return full parameters to optimizer parameters
self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups
def _clear_cache(self):
self.__segment_params.clear()
self._dtype_rank_params.clear()
......
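The step() change above caches the inner optimizer's full parameter list (_ori_parameter_list / _ori_param_groups), swaps in only the current rank's shard for the update, and restores the full list after the broadcast. Below is a minimal, framework-free sketch of that swap-and-restore pattern; the toy class and variable names are illustrative and not part of the Paddle API.

```python
class _ToyInnerOptim:
    """Stand-in for the wrapped optimizer; only the attributes used below."""

    def __init__(self, parameter_list):
        self._parameter_list = parameter_list
        self._param_groups = parameter_list
        self.stepped_with = None

    def step(self):
        # Record which parameters this update actually ran over.
        self.stepped_with = list(self._parameter_list)


class ToyShardedOptimizer:
    """Toy counterpart of the bookkeeping in ShardingOptimizerStage2.step()."""

    def __init__(self, inner_optim, rank_params):
        self._optim = inner_optim
        self._rank_params = rank_params
        # Cache the full lists once, as _ori_parameter_list / _ori_param_groups do.
        self._ori_parameter_list = inner_optim._parameter_list
        self._ori_param_groups = inner_optim._param_groups

    def step(self):
        # Expose only this rank's shard to the inner optimizer for the update.
        self._optim._parameter_list = self._rank_params
        self._optim._param_groups = self._rank_params
        self._optim.step()
        # The real class broadcasts the updated shards across ranks here.
        # Restore the full parameter view so later calls still see everything.
        self._optim._parameter_list = self._ori_parameter_list
        self._optim._param_groups = self._ori_param_groups


if __name__ == "__main__":
    full = ["w0", "w1", "w2", "w3"]
    inner = _ToyInnerOptim(full)
    sharded = ToyShardedOptimizer(inner, rank_params=full[:2])
    sharded.step()
    assert inner.stepped_with == ["w0", "w1"]   # the update saw only the shard
    assert inner._parameter_list == full        # the full view was restored
```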
@@ -29,7 +29,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar
seed = 2021
epoch = 2
batch_size = 32
linear_size = 1000
strategy = fleet.DistributedStrategy()
@@ -86,6 +85,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
def train_mlp(model,
sharding_stage,
batch_size=100,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False):
@@ -103,16 +103,13 @@ def train_mlp(model,
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
if accumulate_grad:
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
else:
model = ShardingStage2(
model, optimizer, group=group, buffer_max_size=2**21)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=batch_size == 20)
else:
optimizer = fleet.distributed_optimizer(optimizer)
model = fleet.distributed_model(model)
@@ -145,12 +142,13 @@ def train_mlp(model,
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
avg_loss.backward()
if not accumulate_grad:
optimizer.step()
optimizer.clear_grad()
if accumulate_grad:
optimizer.step()
optimizer.clear_grad()
if accumulate_grad and batch_id == 2:
return model.parameters()
return model.parameters()
@@ -166,25 +164,22 @@ def test_dp_stage2():
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
dp_params = train_mlp(
mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=True)
mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
stage2_params = train_mlp(
mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=False)
for i in range(len(dp_params)):
for j in range(len(stage2_params)):
if dp_params[i].name == stage2_params[j].name:
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
stage2_params = train_mlp(mlp3, sharding_stage=2)
stage2_accumulate_grad = train_mlp(
mlp4, sharding_stage=2, accumulate_grad=True)
mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
for i in range(len(stage2_params)):
for j in range(len(stage2_accumulate_grad)):
if stage2_params[i].name == stage2_accumulate_grad[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage2_accumulate_grad[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage2_accumulate_grad[i].numpy(),
rtol=1e-5,
atol=1e-5)
return
......
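The unittest above now builds a single ShardingStage2 wrapper with accumulate_grads=batch_size == 20 and compares the accumulated run against the per-batch run with looser tolerances (rtol/atol=1e-5). Below is a tiny, framework-free sketch of the two update schedules being compared, assuming (as the accumulate_grad branches in the loop suggest) that accumulation defers the optimizer step; all names are illustrative and not the test's actual code.

```python
def train(num_batches, accumulate_grad, lr=0.01):
    """Toy loop: step every batch, or step once over accumulated gradients."""
    weight, grad = 1.0, 0.0
    for _ in range(num_batches):
        grad += 0.1                 # stands in for avg_loss.backward()
        if not accumulate_grad:
            weight -= lr * grad     # optimizer.step()
            grad = 0.0              # optimizer.clear_grad()
    if accumulate_grad:
        weight -= lr * grad         # one step over the accumulated gradient
        grad = 0.0
    return weight


if __name__ == "__main__":
    # With this constant toy gradient both schedules land on the same weight;
    # real FP16 training only matches approximately, which is why the test
    # compares the two runs with rtol/atol of 1e-5.
    print(train(4, accumulate_grad=False), train(4, accumulate_grad=True))
```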