Unverified commit 2747de2b authored by caozhou, committed by GitHub

[Auto Parallel]Update reshard for while sub block (#40366)

* update reshard for while sub block

* fix code format error
Parent 575dea8f
......@@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program
from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program
from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static()
_global_parallel_strategy = None
......@@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase):
str(paddle.distributed.get_rank())))
def test_mlp_mp2pp(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
......@@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase):
fetch_list=[loss])
last_res = res[0]
set_default_distributed_context(None)
_global_parallel_strategy = "pp"
_global_process_mesh = auto.ProcessMesh([0, 1])
global PP_MESH_0
......@@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
str(paddle.distributed.get_rank())))
def test_mlp_pp2mp(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
......@@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
if paddle.distributed.get_rank() in [1]:
last_res = res[0]
set_default_distributed_context(None)
_global_parallel_strategy = "mp"
_global_process_mesh = auto.ProcessMesh([0, 1])
......@@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
np.random.seed(2021)
def test_input_invalid(self):
set_default_distributed_context(None)
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
Please register to comment