Unverified commit 147e7a38, authored by sneaxiy, committed via GitHub

Fix the hybrid_parallel_sharding_model.py unit test (#55269)

* fix hybrid_parallel_sharding_model.py

* Update hybrid_parallel_sharding_model.py
Parent commit: 11c26f26
......@@ -296,36 +296,38 @@ class TestDistMPTraning(unittest.TestCase):
def test_sharding_adam(self):
sharded_accumulators = {
'linear_0.w_0_moment1_0',
"linear_0.b_0_moment2_0",
'embedding_0.w_0_beta1_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0',
'linear_0.b_0_beta1_pow_acc_0',
'linear_2.b_0_moment2_0',
'linear_0.b_0_beta2_pow_acc_0',
'linear_1.b_0_moment1_0',
'embedding_0.w_0_moment2_0',
'linear_1.b_0_moment2_0',
'linear_2.b_0_beta1_pow_acc_0',
'linear_0.b_0_moment1_0',
'linear_2.b_0_moment1_0',
'embedding_0.w_0_moment1_0',
'linear_0.w_0_moment2_0',
'linear_1.b_0_moment2_0',
'linear_2.b_0_moment2_0',
'embedding_0.w_0_moment2_0',
'linear_0.w_0_beta1_pow_acc_0',
'embedding_0.w_0_beta2_pow_acc_0',
'linear_1.b_0_beta1_pow_acc_0',
'linear_2.b_0_beta1_pow_acc_0',
'embedding_0.w_0_beta1_pow_acc_0',
'linear_0.w_0_beta2_pow_acc_0',
'linear_1.b_0_beta2_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0',
'embedding_0.w_0_beta2_pow_acc_0',
}
self.sharding_model(
Optimizer="adam", sharded_accumulators=sharded_accumulators
Optimizer="adam",
sharded_accumulators=sharded_accumulators,
)
def test_sharding_momentum(self):
    """Run the sharded model with the Momentum optimizer and verify that
    the optimizer accumulators are sharded as expected.

    Momentum keeps a single velocity accumulator per trainable parameter;
    the layer indices (linear_6..8, embedding_2) differ from the Adam case
    because this test builds a second model instance in the same process.
    """
    # NOTE(review): reconstructed from a corrupted diff view in which old
    # and new lines were interleaved (duplicate 'linear_7.b_0_velocity_0',
    # leftover pre-fix 'linear_6.b_0_velocity_0'). Confirm against commit
    # 147e7a38 in the repository.
    sharded_accumulators = {
        'linear_6.w_0_velocity_0',
        'linear_7.b_0_velocity_0',
        'linear_8.b_0_velocity_0',
        'embedding_2.w_0_velocity_0',
    }
    self.sharding_model(
        Optimizer="Momentum",
        sharded_accumulators=sharded_accumulators,
    )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.