未验证 提交 90f44c6f 编写于 作者: B Baibaifan 提交者: GitHub

fix_stage2_minimize (#39285)

上级 0bb3e5f1
......@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
device="gpu",
**kw):
# super().__init__(optim._learning_rate, params, kw)
super().__init__(optim._learning_rate, params, kw)
# Segmentation information
self._dtype_rank_params = OrderedDict(
......@@ -363,6 +363,10 @@ class ShardingOptimizerStage2(Optimizer):
# Synchronize all the updated shards in between the ranks
self._broadcast_params()
def minimize(self):
raise RuntimeError(
"optimizer.minimize() not support now, please use optimizer.step()")
def _clear_cache(self):
self.__segment_params.clear()
self._dtype_rank_params.clear()
......
......@@ -506,7 +506,13 @@ class ShardingStage3(nn.Layer):
else:
opt_step()
def _opt_minimize(self):
raise RuntimeError(
"optimizer.minimize() not support now, please use optimizer.step()"
)
self._optim.step = MethodType(_opt_step, self._optim)
self._optim.minimize = MethodType(_opt_minimize, self._optim)
def _redefine_opt_clear(self):
clear_func = self._clear_gradients
......
......@@ -124,8 +124,17 @@ def train_mlp():
avg_loss.backward()
oss_optimizer.step()
# oss_optimizer clear cache
oss_optimizer._clear_cache()
# oss_optimizer clear cache
oss_optimizer._clear_cache()
# check optimizer.minimize() error
try:
oss_optimizer.minimize()
except:
print(
"====== Find sharding_stage2_optimizer.minimize() error ======"
)
return
if __name__ == '__main__':
......
......@@ -83,7 +83,8 @@ def train_mlp(model,
accumulate_grad=False,
batch_size=100,
opt_group=False,
recompute=False):
recompute=False,
test_minimize=False):
group = paddle.distributed.new_group([0, 1])
if opt_group:
optimizer = optimizer_setting(
......@@ -113,6 +114,15 @@ def train_mlp(model,
accumulate_grads=batch_size == 20,
sync_comm=recompute)
# check optimizer.minimize() error
if test_minimize:
try:
optimizer.minimize()
except:
print(
"====== Find sharding_stage3_optimizer.minimize() error ======")
return
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
......@@ -160,8 +170,8 @@ def train_mlp(model,
def test_stage2_stage3():
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
......@@ -171,6 +181,8 @@ def test_stage2_stage3():
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)
mlp8.set_state_dict(state_dict)
mlp9.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
......@@ -229,7 +241,14 @@ def test_stage2_stage3():
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
return
# check optimizer.minimize() error
train_mlp(
mlp9,
sharding_stage=3,
use_pure_fp16=False,
opt_group=False,
test_minimize=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册