未验证 提交 2c8739e8 编写于 作者: Z zhaoyingli 提交者: GitHub

use tempfile to place temporary files (#43316)

上级 00ce09e6
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
import time import time
import paddle.fluid as fluid import tempfile
import copy import copy
import os import os
import numpy as np import numpy as np
...@@ -145,7 +145,10 @@ def train(): ...@@ -145,7 +145,10 @@ def train():
engine.predict(test_dataset, batch_size, fetch_list=['label']) engine.predict(test_dataset, batch_size, fetch_list=['label'])
# save # save
engine.save('./mlp_inf', training=False, mode='predict') temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
temp_dir.cleanup()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import os import os
import sys import sys
...@@ -77,16 +78,45 @@ cluster_json = """ ...@@ -77,16 +78,45 @@ cluster_json = """
} }
""" """
mapping_josn = """
[
{
"hostname": "machine1",
"addr": "127.0.0.1",
"port": "768",
"ranks":
{
"0": [1],
"1": [0]
}
}
]
"""
class TestAutoParallelReLaunch(unittest.TestCase): class TestAutoParallelReLaunch(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_relaunch(self): def test_relaunch(self):
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster.json")
mapping_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_rank_mapping.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
mapping_josn_object = json.loads(mapping_josn)
with open(mapping_json_path, "w") as mapping_josn_file:
json.dump(mapping_josn_object, mapping_josn_file)
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, launch_model_path = os.path.join(file_dir,
"auto_parallel_relaunch_model.py") "auto_parallel_relaunch_model.py")
...@@ -96,24 +126,15 @@ class TestAutoParallelReLaunch(unittest.TestCase): ...@@ -96,24 +126,15 @@ class TestAutoParallelReLaunch(unittest.TestCase):
coverage_args = [] coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [ cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path, "-m", "launch", "--log_dir", self.temp_dir.name,
"--enable_auto_mapping", "True", launch_model_path "--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
] ]
process = subprocess.Popen(cmd) process = subprocess.Popen(cmd)
process.wait() process.wait()
self.assertEqual(process.returncode, 0) self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(
file_dir, "auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import os import os
import json import json
...@@ -1968,10 +1969,17 @@ multi_cluster_json = """{ ...@@ -1968,10 +1969,17 @@ multi_cluster_json = """{
class TestCluster(unittest.TestCase): class TestCluster(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_single_machine(self): def test_single_machine(self):
# Build cluster # Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster_single.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
...@@ -1989,14 +1997,10 @@ class TestCluster(unittest.TestCase): ...@@ -1989,14 +1997,10 @@ class TestCluster(unittest.TestCase):
self.assertTrue(devices == [0, 1, 2, 3]) self.assertTrue(devices == [0, 1, 2, 3])
self.assertTrue(involved_machine_count == 1) self.assertTrue(involved_machine_count == 1)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
def test_multi_machine(self): def test_multi_machine(self):
# Build cluster # Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster_multi.json")
cluster_json_object = json.loads(multi_cluster_json) cluster_json_object = json.loads(multi_cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
...@@ -2014,10 +2018,6 @@ class TestCluster(unittest.TestCase): ...@@ -2014,10 +2018,6 @@ class TestCluster(unittest.TestCase):
self.assertTrue(devices == [5, 6, 7, 10]) self.assertTrue(devices == [5, 6, 7, 10])
self.assertTrue(involved_machine_count == 2) self.assertTrue(involved_machine_count == 2)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import os import os
import json import json
import tempfile
import paddle import paddle
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.cluster import Cluster
...@@ -32,10 +33,16 @@ from test_cluster import cluster_json, multi_cluster_json ...@@ -32,10 +33,16 @@ from test_cluster import cluster_json, multi_cluster_json
class TestCommOpCost(unittest.TestCase): class TestCommOpCost(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_comm_cost(self): def test_comm_cost(self):
# Build cluster # Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster0.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
...@@ -92,14 +99,10 @@ class TestCommOpCost(unittest.TestCase): ...@@ -92,14 +99,10 @@ class TestCommOpCost(unittest.TestCase):
comm_context=comm_context) comm_context=comm_context)
self.assertTrue(identity_op_cost.time >= 0) self.assertTrue(identity_op_cost.time >= 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
def test_cross_machine_comm_cost(self): def test_cross_machine_comm_cost(self):
# Build cluster # Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster1.json")
cluster_json_object = json.loads(multi_cluster_json) cluster_json_object = json.loads(multi_cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
...@@ -151,10 +154,6 @@ class TestCommOpCost(unittest.TestCase): ...@@ -151,10 +154,6 @@ class TestCommOpCost(unittest.TestCase):
comm_context=comm_context) comm_context=comm_context)
self.assertTrue(recv_op_cost.time > 0) self.assertTrue(recv_op_cost.time > 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import os import os
import sys import sys
...@@ -31,24 +32,17 @@ class TestEngineAPI(unittest.TestCase): ...@@ -31,24 +32,17 @@ class TestEngineAPI(unittest.TestCase):
else: else:
coverage_args = [] coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [ cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--gpus", "0,1", launch_model_path "-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name,
launch_model_path
] ]
process = subprocess.Popen(cmd) process = subprocess.Popen(cmd)
process.wait() process.wait()
self.assertEqual(process.returncode, 0) self.assertEqual(process.returncode, 0)
# Remove unnecessary files tmp_dir.cleanup()
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
files_path = [path for path in os.listdir('.') if '.pd' in path]
for path in files_path:
if os.path.exists(path):
os.remove(path)
if os.path.exists('rank_mapping.csv'):
os.remove('rank_mapping.csv')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import os import os
import json import json
import tempfile
import paddle import paddle
import paddle.distributed.auto_parallel.cost as cost_model import paddle.distributed.auto_parallel.cost as cost_model
...@@ -36,6 +37,12 @@ def check_cost(cost): ...@@ -36,6 +37,12 @@ def check_cost(cost):
class TestCost(unittest.TestCase): class TestCost(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_base_cost(self): def test_base_cost(self):
cost = cost_model.Cost(memory=100, flops=200, time=0.5) cost = cost_model.Cost(memory=100, flops=200, time=0.5)
self.assertTrue(check_cost(cost)) self.assertTrue(check_cost(cost))
...@@ -65,8 +72,8 @@ class TestCost(unittest.TestCase): ...@@ -65,8 +72,8 @@ class TestCost(unittest.TestCase):
def test_comm_cost(self): def test_comm_cost(self):
# Build cluster # Build cluster
file_dir = os.path.dirname(os.path.abspath(__file__)) cluster_json_path = os.path.join(self.temp_dir.name,
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
...@@ -85,10 +92,6 @@ class TestCost(unittest.TestCase): ...@@ -85,10 +92,6 @@ class TestCost(unittest.TestCase):
op_desc=desc, comm_context=CommContext(cluster)) op_desc=desc, comm_context=CommContext(cluster))
self.assertTrue(check_cost(allreduce_cost.cost)) self.assertTrue(check_cost(allreduce_cost.cost))
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
def test_cost_estimator(self): def test_cost_estimator(self):
train_program = paddle.static.Program() train_program = paddle.static.Program()
cost_estimator = cost_model.CostEstimator(train_program) cost_estimator = cost_model.CostEstimator(train_program)
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import os import os
import sys import sys
...@@ -23,14 +24,29 @@ from paddle.distributed.fleet.launch_utils import run_with_coverage ...@@ -23,14 +24,29 @@ from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase): class TestPlannerReLaunch(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_relaunch_with_planner(self): def test_relaunch_with_planner(self):
from test_auto_parallel_relaunch import cluster_json from test_auto_parallel_relaunch import cluster_json, mapping_josn
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
mapping_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_rank_mapping.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
mapping_json_object = json.loads(mapping_josn)
with open(mapping_json_path, "w") as mapping_json_file:
json.dump(mapping_json_object, mapping_json_file)
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join( launch_model_path = os.path.join(
file_dir, "auto_parallel_relaunch_with_gpt_planner.py") file_dir, "auto_parallel_relaunch_with_gpt_planner.py")
...@@ -40,28 +56,15 @@ class TestPlannerReLaunch(unittest.TestCase): ...@@ -40,28 +56,15 @@ class TestPlannerReLaunch(unittest.TestCase):
coverage_args = [] coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [ cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path, "-m", "launch", "--log_dir", self.temp_dir.name,
"--enable_auto_mapping", "True", launch_model_path "--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
] ]
process = subprocess.Popen(cmd) process = subprocess.Popen(cmd)
process.wait() process.wait()
self.assertEqual(process.returncode, 0) self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(
file_dir, "auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
files_path = [path for path in os.listdir('.') if '.pkl' in path]
for path in files_path:
if os.path.exists(path):
os.remove(path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import os import os
import sys import sys
...@@ -23,14 +24,29 @@ from paddle.distributed.fleet.launch_utils import run_with_coverage ...@@ -23,14 +24,29 @@ from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase): class TestPlannerReLaunch(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_relaunch_with_planner(self): def test_relaunch_with_planner(self):
from test_auto_parallel_relaunch import cluster_json from test_auto_parallel_relaunch import cluster_json, mapping_josn
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json") cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
mapping_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_rank_mapping.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
mapping_json_object = json.loads(mapping_josn)
with open(mapping_json_path, "w") as mapping_json_file:
json.dump(mapping_json_object, mapping_json_file)
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join( launch_model_path = os.path.join(
file_dir, "auto_parallel_relaunch_with_planner.py") file_dir, "auto_parallel_relaunch_with_planner.py")
...@@ -40,24 +56,15 @@ class TestPlannerReLaunch(unittest.TestCase): ...@@ -40,24 +56,15 @@ class TestPlannerReLaunch(unittest.TestCase):
coverage_args = [] coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [ cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path, "-m", "launch", "--log_dir", self.temp_dir.name,
"--enable_auto_mapping", "True", launch_model_path "--cluster_topo_path", cluster_json_path, "--rank_mapping_path",
mapping_json_path, "--enable_auto_mapping", "True",
launch_model_path
] ]
process = subprocess.Popen(cmd) process = subprocess.Popen(cmd)
process.wait() process.wait()
self.assertEqual(process.returncode, 0) self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(
file_dir, "auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import tempfile
import unittest import unittest
import os import os
import json import json
...@@ -201,15 +202,21 @@ cluster_json = """ ...@@ -201,15 +202,21 @@ cluster_json = """
class TestAutoParallelCluster(unittest.TestCase): class TestAutoParallelCluster(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_cluster(self): def test_cluster(self):
cluster_json_file = "" cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open("./auto_parallel_cluster.json", "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster() cluster = Cluster()
cluster.build_from_file("./auto_parallel_cluster.json") cluster.build_from_file(cluster_json_path)
os.remove("./auto_parallel_cluster.json")
self.assertEqual(len(cluster.get_all_devices("GPU")), 4) self.assertEqual(len(cluster.get_all_devices("GPU")), 4)
self.assertEqual(len(cluster.get_all_devices("CPU")), 2) self.assertEqual(len(cluster.get_all_devices("CPU")), 2)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import tempfile
import unittest import unittest
import os import os
import json import json
...@@ -527,14 +528,20 @@ def get_device_local_ids(machine): ...@@ -527,14 +528,20 @@ def get_device_local_ids(machine):
class TestAutoParallelMapper(unittest.TestCase): class TestAutoParallelMapper(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_mapper_dp_mp_pp(self): def test_mapper_dp_mp_pp(self):
cluster_json_file = "" cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json) cluster_json_object = json.loads(cluster_json)
with open("./auto_parallel_cluster.json", "w") as cluster_json_file: with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file) json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster() cluster = Cluster()
cluster.build_from_file("./auto_parallel_cluster.json") cluster.build_from_file(cluster_json_path)
os.remove("./auto_parallel_cluster.json")
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp_pp" _global_parallel_strategy = "dp_mp_pp"
......
...@@ -892,25 +892,6 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -892,25 +892,6 @@ class TestGPTPartitioner(unittest.TestCase):
auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition( auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
complete_train_program, startup_program, params_grads) complete_train_program, startup_program, params_grads)
with open("./test_auto_parallel_partitioner_serial_main_new.txt",
"w") as fw:
fw.write(str(train_program))
with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
"w") as fw:
fw.write(str(startup_program))
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
fw.write(str(auto_parallel_main_prog))
with open("./test_auto_parallel_partitioner_startup_new.txt1",
"w") as fw:
fw.write(str(auto_parallel_startup_prog))
# with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
# from paddle.distributed.auto_parallel.completion import Completer
# completer = Completer()
# completer.complete_forward_annotation(auto_parallel_main_prog)
# fw.write(str(auto_parallel_main_prog))
nrank = 4 nrank = 4
# col parallel # col parallel
weights = [ weights = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册