未验证 提交 0a0f3ef9 编写于 作者: W Wennie396 提交者: GitHub

Fix sharding_pass_unittest stage3 precision problem (#55613)

* fix sharding_pass stage3 precision problem

* delete 'stage3 has precision problem' comment

* add dp2 training after load dp_engine

* unset grad_clip=clip for opt
上级 53584ebc
...@@ -27,7 +27,7 @@ paddle.enable_static() ...@@ -27,7 +27,7 @@ paddle.enable_static()
def apply_pass(use_sharding=False, stage=None): def apply_pass(use_sharding=False, stage=None):
strategy = auto.Strategy() strategy = auto.Strategy()
strategy.auto_mode = "semi" strategy.auto_mode = "semi"
strategy.reinit = True # strategy.reinit = True
if use_sharding: if use_sharding:
sharding = strategy.sharding sharding = strategy.sharding
sharding.enable = True sharding.enable = True
...@@ -50,6 +50,7 @@ def apply_pass(use_sharding=False, stage=None): ...@@ -50,6 +50,7 @@ def apply_pass(use_sharding=False, stage=None):
def reset_prog(): def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program()) paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program()) paddle.fluid.framework.switch_startup_program(paddle.static.Program())
paddle.utils.unique_name.switch()
class TestShardingPass(unittest.TestCase): class TestShardingPass(unittest.TestCase):
...@@ -65,15 +66,14 @@ class TestShardingPass(unittest.TestCase): ...@@ -65,15 +66,14 @@ class TestShardingPass(unittest.TestCase):
paddle.seed(2022) paddle.seed(2022)
np.random.seed(2022) np.random.seed(2022)
random.seed(2022) random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_sharding=False, stage=None): def get_engine(self, use_sharding=False, stage=None):
reset_prog() reset_prog()
strategy = apply_pass(use_sharding, stage) strategy = apply_pass(use_sharding, stage)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) # NOTE: seting opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) will cause precision problem
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
model, loss = generate_model("dp") model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy) engine = auto.Engine(model, loss, opt, strategy=strategy)
...@@ -81,11 +81,9 @@ class TestShardingPass(unittest.TestCase): ...@@ -81,11 +81,9 @@ class TestShardingPass(unittest.TestCase):
return engine return engine
def check_results(self, ref_losses, check_losses): def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose( np.testing.assert_equal(
ref_losses, ref_losses,
check_losses, check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses __class__, ref_losses, check_losses, ref_losses - check_losses
), ),
...@@ -94,11 +92,40 @@ class TestShardingPass(unittest.TestCase): ...@@ -94,11 +92,40 @@ class TestShardingPass(unittest.TestCase):
def test_sharding_pass(self): def test_sharding_pass(self):
# dp2 training # dp2 training
dp_engine = self.get_engine() dp_engine = self.get_engine()
input_spec = [
paddle.static.InputSpec([self.batch_size, 512], 'int64', 'tokens'),
paddle.static.InputSpec(
[self.batch_size, 512], 'int64', 'position_ids'
),
paddle.static.InputSpec(
[self.batch_size, 1, 512, 512], 'float32', 'attention_mask'
),
]
label_spec = [
paddle.static.InputSpec([self.batch_size, 512], 'int64', 'label'),
paddle.static.InputSpec(
[self.batch_size, 512], 'float32', 'loss_mask'
),
]
dp_engine.prepare(
inputs_spec=input_spec, labels_spec=label_spec, mode='train'
)
dp_engine.save("./dp_engine", training=True)
history = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(history.history["loss"]) dp_losses = np.array(history.history["loss"])
# dp2 training after load dp_engine
dp_load_engine = self.get_engine()
dp_load_engine.load("./dp_engine")
history = dp_load_engine.fit(
self.dataset, 3, batch_size=self.batch_size
)
dp_load_losses2 = np.array(history.history["loss"])
self.check_results(dp_losses, dp_load_losses2)
# sharding2 stage1 training # sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1) sharding1_engine = self.get_engine(True, 1)
sharding1_engine.load("./dp_engine")
history = sharding1_engine.fit( history = sharding1_engine.fit(
self.dataset, 3, batch_size=self.batch_size self.dataset, 3, batch_size=self.batch_size
) )
...@@ -107,6 +134,7 @@ class TestShardingPass(unittest.TestCase): ...@@ -107,6 +134,7 @@ class TestShardingPass(unittest.TestCase):
# sharding2 stage2 training # sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2) sharding2_engine = self.get_engine(True, 2)
sharding2_engine.load("./dp_engine")
history = sharding2_engine.fit( history = sharding2_engine.fit(
self.dataset, 3, batch_size=self.batch_size self.dataset, 3, batch_size=self.batch_size
) )
...@@ -115,12 +143,12 @@ class TestShardingPass(unittest.TestCase): ...@@ -115,12 +143,12 @@ class TestShardingPass(unittest.TestCase):
# sharding2 stage3 training # sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3) sharding3_engine = self.get_engine(True, 3)
sharding3_engine.load("./dp_engine")
history = sharding3_engine.fit( history = sharding3_engine.fit(
self.dataset, 3, batch_size=self.batch_size self.dataset, 3, batch_size=self.batch_size
) )
sharding3_losses = np.array(history.history["loss"]) sharding3_losses = np.array(history.history["loss"])
# NOTE: stage3 has precision problem self.check_results(dp_losses, sharding3_losses)
# self.check_results(dp_losses, sharding3_losses)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册