From 147e7a388acaac0042478571441bd8892460fcd8 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 12 Jul 2023 09:27:26 +0800 Subject: [PATCH] Fix hybrid_parallel_sharding_model.py ut (#55269) * fix hybrid_parallel_sharding_model.py * Update hybrid_parallel_sharding_model.py --- .../fleet/hybrid_parallel_sharding_model.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py index f56c5513782..82c132df996 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py @@ -296,36 +296,38 @@ class TestDistMPTraning(unittest.TestCase): def test_sharding_adam(self): sharded_accumulators = { - 'linear_0.w_0_moment1_0', + "linear_0.b_0_moment2_0", + 'embedding_0.w_0_beta1_pow_acc_0', + 'linear_2.b_0_beta2_pow_acc_0', + 'linear_0.b_0_beta1_pow_acc_0', + 'linear_2.b_0_moment2_0', + 'linear_0.b_0_beta2_pow_acc_0', 'linear_1.b_0_moment1_0', + 'embedding_0.w_0_moment2_0', + 'linear_1.b_0_moment2_0', + 'linear_2.b_0_beta1_pow_acc_0', + 'linear_0.b_0_moment1_0', 'linear_2.b_0_moment1_0', 'embedding_0.w_0_moment1_0', - 'linear_0.w_0_moment2_0', - 'linear_1.b_0_moment2_0', - 'linear_2.b_0_moment2_0', - 'embedding_0.w_0_moment2_0', - 'linear_0.w_0_beta1_pow_acc_0', + 'embedding_0.w_0_beta2_pow_acc_0', 'linear_1.b_0_beta1_pow_acc_0', - 'linear_2.b_0_beta1_pow_acc_0', - 'embedding_0.w_0_beta1_pow_acc_0', - 'linear_0.w_0_beta2_pow_acc_0', 'linear_1.b_0_beta2_pow_acc_0', - 'linear_2.b_0_beta2_pow_acc_0', - 'embedding_0.w_0_beta2_pow_acc_0', } self.sharding_model( - Optimizer="adam", sharded_accumulators=sharded_accumulators + Optimizer="adam", + sharded_accumulators=sharded_accumulators, ) def test_sharding_momentum(self): sharded_accumulators = { - 'linear_6.w_0_velocity_0', - 'linear_7.b_0_velocity_0', 'linear_8.b_0_velocity_0', 'embedding_2.w_0_velocity_0', + 'linear_6.b_0_velocity_0', + 'linear_7.b_0_velocity_0', } self.sharding_model( - Optimizer="Momentum", sharded_accumulators=sharded_accumulators + Optimizer="Momentum", + sharded_accumulators=sharded_accumulators, ) -- GitLab