未验证 提交 147e7a38 编写于 作者: S sneaxiy 提交者: GitHub

Fix hybrid_parallel_sharding_model.py ut (#55269)

* fix hybrid_parallel_sharding_model.py

* Update hybrid_parallel_sharding_model.py
上级 11c26f26
...@@ -296,36 +296,38 @@ class TestDistMPTraning(unittest.TestCase): ...@@ -296,36 +296,38 @@ class TestDistMPTraning(unittest.TestCase):
def test_sharding_adam(self): def test_sharding_adam(self):
sharded_accumulators = { sharded_accumulators = {
'linear_0.w_0_moment1_0', "linear_0.b_0_moment2_0",
'embedding_0.w_0_beta1_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0',
'linear_0.b_0_beta1_pow_acc_0',
'linear_2.b_0_moment2_0',
'linear_0.b_0_beta2_pow_acc_0',
'linear_1.b_0_moment1_0', 'linear_1.b_0_moment1_0',
'embedding_0.w_0_moment2_0',
'linear_1.b_0_moment2_0',
'linear_2.b_0_beta1_pow_acc_0',
'linear_0.b_0_moment1_0',
'linear_2.b_0_moment1_0', 'linear_2.b_0_moment1_0',
'embedding_0.w_0_moment1_0', 'embedding_0.w_0_moment1_0',
'linear_0.w_0_moment2_0', 'embedding_0.w_0_beta2_pow_acc_0',
'linear_1.b_0_moment2_0',
'linear_2.b_0_moment2_0',
'embedding_0.w_0_moment2_0',
'linear_0.w_0_beta1_pow_acc_0',
'linear_1.b_0_beta1_pow_acc_0', 'linear_1.b_0_beta1_pow_acc_0',
'linear_2.b_0_beta1_pow_acc_0',
'embedding_0.w_0_beta1_pow_acc_0',
'linear_0.w_0_beta2_pow_acc_0',
'linear_1.b_0_beta2_pow_acc_0', 'linear_1.b_0_beta2_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0',
'embedding_0.w_0_beta2_pow_acc_0',
} }
self.sharding_model( self.sharding_model(
Optimizer="adam", sharded_accumulators=sharded_accumulators Optimizer="adam",
sharded_accumulators=sharded_accumulators,
) )
def test_sharding_momentum(self): def test_sharding_momentum(self):
sharded_accumulators = { sharded_accumulators = {
'linear_6.w_0_velocity_0',
'linear_7.b_0_velocity_0',
'linear_8.b_0_velocity_0', 'linear_8.b_0_velocity_0',
'embedding_2.w_0_velocity_0', 'embedding_2.w_0_velocity_0',
'linear_6.b_0_velocity_0',
'linear_7.b_0_velocity_0',
} }
self.sharding_model( self.sharding_model(
Optimizer="Momentum", sharded_accumulators=sharded_accumulators Optimizer="Momentum",
sharded_accumulators=sharded_accumulators,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册