未验证 提交 438ca7f6 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix unittest with paddle.distributed.launch (#44439)

* fix unittest

* fix log_dir

* _enable_legacy_dygraph
上级 98e96853
......@@ -16,10 +16,12 @@ from collections import OrderedDict
import paddle
import paddle.fluid.core as core
from ..collective import _get_global_env
from ..collective import _new_ring_id
from ...fluid.framework import _non_static_mode
from ...fluid.layers.tensor import fill_constant
from paddle.fluid.framework import _enable_legacy_dygraph
def get_all_process_groups():
......@@ -134,7 +136,8 @@ class ProcessGroup:
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
paddle.framework._in_legacy_dygraph()
paddle.disable_static()
_enable_legacy_dygraph()
paddle.set_device('gpu:%d' %
paddle.distributed.ParallelEnv().dev_id)
tmp = paddle.to_tensor(
......
......@@ -126,7 +126,7 @@ class TestAutoParallelReLaunch(unittest.TestCase):
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--log_dir", self.temp_dir.name,
"-m", "paddle.distributed.launch", "--log_dir", self.temp_dir.name,
"--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
......@@ -32,18 +33,17 @@ class TestConverter(unittest.TestCase):
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
tmp_dir.cleanup()
def test_input_invalid(self):
with self.assertRaises(ValueError):
......
......@@ -34,8 +34,8 @@ class TestEngineAPI(unittest.TestCase):
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name,
launch_model_path
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
......
......@@ -34,8 +34,8 @@ class TestEngineAPI(unittest.TestCase):
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name,
launch_model_path
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
......@@ -31,18 +32,17 @@ class TestHighOrderGrad(unittest.TestCase):
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
tmp_dir.cleanup()
if __name__ == "__main__":
......
......@@ -56,7 +56,7 @@ class TestPlannerReLaunch(unittest.TestCase):
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--log_dir", self.temp_dir.name,
"-m", "paddle.distributed.launch", "--log_dir", self.temp_dir.name,
"--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
......
......@@ -56,7 +56,7 @@ class TestPlannerReLaunch(unittest.TestCase):
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--log_dir", self.temp_dir.name,
"-m", "paddle.distributed.launch", "--log_dir", self.temp_dir.name,
"--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册