Unverified commit 90f44c6f, authored by Baibaifan, committed by GitHub

fix_stage2_minimize (#39285)

Parent: 0bb3e5f1
@@ -70,7 +70,7 @@ class ShardingOptimizerStage2(Optimizer):
                  device="gpu",
                  **kw):
 
-        # super().__init__(optim._learning_rate, params, kw)
+        super().__init__(optim._learning_rate, params, kw)
 
         # Segmentation information
         self._dtype_rank_params = OrderedDict(
@@ -363,6 +363,10 @@ class ShardingOptimizerStage2(Optimizer):
         # Synchronize all the updated shards in between the ranks
         self._broadcast_params()
 
+    def minimize(self):
+        raise RuntimeError(
+            "optimizer.minimize() not support now, please use optimizer.step()")
+
     def _clear_cache(self):
         self.__segment_params.clear()
         self._dtype_rank_params.clear()
......
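Note on the hunk above: the stage-2 optimizer now actually calls the base `Optimizer.__init__` (previously commented out) and overrides `minimize()` to fail fast, steering callers to `optimizer.step()`. Below is a minimal sketch of that pattern with hypothetical names (`ShardedOptimizerSketch` is not the Paddle source):

```python
# Sketch only: a sharded optimizer wrapper that supports step() but blocks minimize().
class ShardedOptimizerSketch:
    def __init__(self, inner_optimizer):
        # the wrapped rank-local optimizer that owns this rank's parameter shard
        self._optim = inner_optimizer

    def step(self):
        # update the local shard; the real code then broadcasts updated shards
        self._optim.step()

    def minimize(self, *args, **kwargs):
        # mirror the patch: raise instead of silently updating the wrong state
        raise RuntimeError(
            "optimizer.minimize() not support now, please use optimizer.step()")
```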
@@ -506,7 +506,13 @@ class ShardingStage3(nn.Layer):
             else:
                 opt_step()
 
+        def _opt_minimize(self):
+            raise RuntimeError(
+                "optimizer.minimize() not support now, please use optimizer.step()"
+            )
+
         self._optim.step = MethodType(_opt_step, self._optim)
+        self._optim.minimize = MethodType(_opt_minimize, self._optim)
 
     def _redefine_opt_clear(self):
         clear_func = self._clear_gradients
......
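The stage-3 wrapper patches `minimize` onto the wrapped optimizer instance with `types.MethodType`, alongside the existing `step` override. A small self-contained illustration of that binding pattern (`_DummyOptimizer` is made up for the example, not Paddle code):

```python
from types import MethodType


class _DummyOptimizer:
    def step(self):
        print("step called")


def _opt_minimize(self):
    raise RuntimeError(
        "optimizer.minimize() not support now, please use optimizer.step()")


opt = _DummyOptimizer()
# bind the function to this one instance, as sharding_stage3 does for self._optim
opt.minimize = MethodType(_opt_minimize, opt)

opt.step()          # still works: prints "step called"
try:
    opt.minimize()  # now raises the RuntimeError defined above
except RuntimeError as err:
    print(err)
```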
@@ -124,8 +124,17 @@ def train_mlp():
         avg_loss.backward()
 
         oss_optimizer.step()
 
         # oss_optimizer clear cache
         oss_optimizer._clear_cache()
 
+    # check optimizer.minimize() error
+    try:
+        oss_optimizer.minimize()
+    except:
+        print(
+            "====== Find sharding_stage2_optimizer.minimize() error ======"
+        )
+    return
 
 
 if __name__ == '__main__':
......
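The test above exercises the new error path with a bare `try/except` and a printed marker that the launcher log can be grepped for. Purely as an illustrative alternative (not part of this PR), the same check could assert the exception type directly; `FakeShardedOptimizer` below is a stand-in, not Paddle code:

```python
import unittest


class FakeShardedOptimizer:
    """Stand-in that mimics the blocked minimize() behaviour."""

    def minimize(self):
        raise RuntimeError(
            "optimizer.minimize() not support now, please use optimizer.step()")


class TestMinimizeBlocked(unittest.TestCase):
    def test_minimize_raises(self):
        # assertRaises fails the test if no RuntimeError is raised
        with self.assertRaises(RuntimeError):
            FakeShardedOptimizer().minimize()


if __name__ == '__main__':
    unittest.main()
```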
@@ -83,7 +83,8 @@ def train_mlp(model,
               accumulate_grad=False,
               batch_size=100,
               opt_group=False,
-              recompute=False):
+              recompute=False,
+              test_minimize=False):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
         optimizer = optimizer_setting(
@@ -113,6 +114,15 @@ def train_mlp(model,
             accumulate_grads=batch_size == 20,
             sync_comm=recompute)
 
+    # check optimizer.minimize() error
+    if test_minimize:
+        try:
+            optimizer.minimize()
+        except:
+            print(
+                "====== Find sharding_stage3_optimizer.minimize() error ======")
+        return
+
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
@@ -160,8 +170,8 @@ def train_mlp(model,
 def test_stage2_stage3():
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
-    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
     state_dict = mlp.state_dict()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
@@ -171,6 +181,8 @@ def test_stage2_stage3():
     mlp6.set_state_dict(state_dict)
     mlp7.set_state_dict(state_dict)
     mlp8.set_state_dict(state_dict)
+    mlp9.set_state_dict(state_dict)
+
     # fp32
     stage2_params = train_mlp(
         mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
@@ -229,7 +241,14 @@ def test_stage2_stage3():
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
-    return
+
+    # check optimizer.minimize() error
+    train_mlp(
+        mlp9,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        test_minimize=True)
 
 
 if __name__ == '__main__':
......