Unverified · Commit 08524758 authored by 姜永久, committed by GitHub

remove fleet eager guard tests (#48765)

Parent f3982a9d
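For context, here is a minimal sketch of the pattern this commit removes. The module name and the trivial test body are hypothetical, not taken from the diff; the point is only the entry-point change: previously the call was wrapped in a with _test_eager_guard(): block (imported from paddle.fluid.framework) to force eager dygraph mode, and after this change the guard and its import are dropped so the entry point runs directly, presumably because eager mode is now Paddle's default.

# hypothetical_sharding_test.py -- illustrative only, not a file in this commit
import paddle.distributed.fleet as fleet


def test_sharding_api():
    # Placeholder body; the real tests build group-sharded models,
    # optimizers and scalers before asserting on their losses.
    pass


if __name__ == '__main__':
    # Before this commit the block read:
    #     with _test_eager_guard():
    #         pass
    #     fleet.init(is_collective=True)
    #     test_sharding_api()
    # Now the guard is gone and the test runs directly in (default) eager mode.
    fleet.init(is_collective=True)
    test_sharding_api()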
@@ -24,7 +24,6 @@ from paddle.distributed.sharding import (
     group_sharded_parallel,
     save_group_sharded_model,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -196,7 +195,5 @@ def test_sharding_api():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     fleet.init(is_collective=True)
     test_sharding_api()
@@ -22,7 +22,6 @@ from paddle.distributed.sharding import (
     group_sharded_parallel,
     save_group_sharded_model,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -199,5 +198,4 @@ def test_sharding_api():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_sharding_api()
+    test_sharding_api()
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_sta
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
     GroupShardedStage2,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 seed = 2022
@@ -246,5 +245,4 @@ def test_dp_stage2():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_dp_stage2()
+    test_dp_stage2()
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_sta
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
     GroupShardedStage2,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 seed = 2022
@@ -250,5 +249,4 @@ def test_dp_stage2():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_dp_stage2()
+    test_dp_stage2()
@@ -31,7 +31,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 
 seed = 2021
 epoch = 2
@@ -115,5 +114,4 @@ def test_sharding_stage2_offload():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_sharding_stage2_offload()
+    test_sharding_stage2_offload()
@@ -34,7 +34,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -320,5 +319,4 @@ def test_stage2_stage3():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_stage2_stage3()
+    test_stage2_stage3()
@@ -24,7 +24,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -220,5 +219,4 @@ def test_stage3_offload():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        test_stage3_offload()
+    test_stage3_offload()
@@ -23,7 +23,6 @@ from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimiz
     ShardingOptimizerStage2,
 )
 from paddle.distributed.fleet.utils.internal_storage import GradStorage
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 base_lr = 0.1
@@ -142,6 +141,4 @@ def train_mlp():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     train_mlp()
@@ -29,7 +29,6 @@ from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimiz
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import (
     ShardingStage2,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 seed = 2022
@@ -239,7 +238,5 @@ def test_dp_stage2():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     fleet.init(is_collective=True, strategy=strategy)
     test_dp_stage2()
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import (
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import (
     ShardingScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 
 seed = 2021
 epoch = 2
@@ -119,7 +118,5 @@ def test_sharding_stage2_offload():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     fleet.init(is_collective=True, strategy=strategy)
     test_sharding_stage2_offload()
@@ -35,7 +35,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import (
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import (
     ShardingScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -316,7 +315,5 @@ def test_stage2_stage3():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     fleet.init(is_collective=True)
     test_stage2_stage3()
@@ -25,7 +25,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import (
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import (
     ShardingScaler,
 )
-from paddle.fluid.framework import _test_eager_guard
 from paddle.nn import Linear
 
 epoch = 10
@@ -216,7 +215,5 @@ def test_stage3_offload():
 
 
 if __name__ == '__main__':
-    with _test_eager_guard():
-        pass
     fleet.init(is_collective=True)
     test_stage3_offload()
@@ -15,7 +15,6 @@
 import unittest
 
 import paddle
-from paddle.fluid.framework import _test_eager_guard
 
 
 class TestProcessGroupFp32(unittest.TestCase):
@@ -26,15 +25,14 @@ class TestProcessGroupFp32(unittest.TestCase):
         pass
 
     def test_init_process_group(self):
-        with _test_eager_guard():
-            paddle.distributed.init_parallel_env()
-            paddle.distributed.new_group()
-            group = paddle.distributed.new_group([-1, -2])
-            assert group.process_group is None
+        paddle.distributed.init_parallel_env()
+        paddle.distributed.new_group()
+        group = paddle.distributed.new_group([-1, -2])
+        assert group.process_group is None
 
-            group = paddle.distributed.collective.Group(-1, 2, 0, [-1, -2])
-            ret = paddle.distributed.barrier(group)
-            assert ret is None
+        group = paddle.distributed.collective.Group(-1, 2, 0, [-1, -2])
+        ret = paddle.distributed.barrier(group)
+        assert ret is None
         paddle.enable_static()
         in_tensor = paddle.empty((1, 2))
         in_tensor2 = paddle.empty((1, 2))
......