diff --git a/test/auto_parallel/test_pass_sharding.py b/test/auto_parallel/test_pass_sharding.py index d9f514e9473133cda1023050dc49dc1b86ee61f3..d6062f216f66aff17bd0e2b68ec495a74d4d172b 100644 --- a/test/auto_parallel/test_pass_sharding.py +++ b/test/auto_parallel/test_pass_sharding.py @@ -36,6 +36,8 @@ class TestShardingPass(unittest.TestCase): + [ "-m", "paddle.distributed.launch", + "--devices", + "0,1", "--log_dir", tmp_dir.name, launch_model_path,