diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index eba71a86f576e0a5305d43bffa956ecfbc1cc7ef..d210bc12f529680f5a266a508aaf64c3ce670c2e 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -1790,7 +1790,7 @@ def group_param(sharding_info, fuse_size):
 class ShardingInfo:
     def __init__(self, group, rank, params_grads, partition_algor):
         self.group = group
-        self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
+        self.params_grads = {p.name: (p, g) for p, g in params_grads}
         assert len(self.params_grads) == len(
             set(self.params_grads)
         ), "found duplicated param in params_grads"
diff --git a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py
index 3fa9c12529272c495644508e947d63c6a3f973b2..bcf2b7c49b6fc31c68a91392228f2d358b055092 100644
--- a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py
+++ b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py
@@ -231,9 +231,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
 
     def reader():
         for _ in range(6):
-            yield dict(
-                [(grad.name, gen_random_grad_tensor(grad)) for grad in grads]
-            )
+            yield {grad.name: gen_random_grad_tensor(grad) for grad in grads}
 
     scope = paddle.static.Scope()
     fetch_list = params
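
Both hunks replace dict() called on a list comprehension of (key, value) tuples with an equivalent dict comprehension, which skips the intermediate list and reads more directly. The standalone sketch below (not part of the patched files; the Param namedtuple and the sample params_grads list are hypothetical stand-ins for Paddle's parameter/gradient objects) illustrates that the two forms build the same mapping:

# A minimal sketch, assuming params_grads is a list of (param, grad) pairs where
# each param exposes a .name attribute, as in the ShardingInfo change above.
from collections import namedtuple

Param = namedtuple("Param", ["name"])  # hypothetical stand-in for a Paddle parameter
params_grads = [(Param("w1"), "w1@GRAD"), (Param("w2"), "w2@GRAD")]

# Old form: materialize a list of (key, value) tuples, then hand it to dict().
via_dict_call = dict([(p.name, (p, g)) for p, g in params_grads])

# New form: build the mapping directly with a dict comprehension.
via_comprehension = {p.name: (p, g) for p, g in params_grads}

assert via_dict_call == via_comprehension  # identical contents, no intermediate list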