未验证 提交 b2224e6f 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] fix mlu ctest final. (#44404)

上级 1d128326
...@@ -17,5 +17,4 @@ ...@@ -17,5 +17,4 @@
set -e set -e
# use default values # use default values
# FIXME: random fails on Unknown command lines -c (or -m). # FIXME: random fails on Unknown command lines -c (or -m).
launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch c_comm_init_op_mlu.py
MLU_VISIBLE_DEVICES=0,1 python ${launch_py} c_comm_init_op_mlu.py
...@@ -16,13 +16,13 @@ from __future__ import print_function ...@@ -16,13 +16,13 @@ from __future__ import print_function
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import sys
sys.path.append("..")
from op_test import OpTest from op_test import OpTest
import numpy as np import numpy as np
import unittest import unittest
import sys
sys.path.append("..")
paddle.enable_static() paddle.enable_static()
SEED = 2021 SEED = 2021
......
...@@ -102,12 +102,13 @@ class TestSpawn(unittest.TestCase): ...@@ -102,12 +102,13 @@ class TestSpawn(unittest.TestCase):
self.assertEqual(nprocs, core.get_mlu_device_count()) self.assertEqual(nprocs, core.get_mlu_device_count())
def test_spawn(self): def test_spawn(self):
context = dist.spawn(train, backend='cncl', nprocs=4) num_devs = core.get_mlu_device_count()
context = dist.spawn(train, backend='cncl', nprocs=num_devs)
rank_list = [] rank_list = []
for i in range(4): for i in range(num_devs):
rank_list.append(context.return_queues[i].get()) rank_list.append(context.return_queues[i].get())
rank_list.sort() rank_list.sort()
self.assertEqual(rank_list, list(range(4))) self.assertEqual(rank_list, list(range(num_devs)))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册