From b2224e6f391f3c400a754deaf6af32f722f1a7a0 Mon Sep 17 00:00:00 2001 From: Chenxiao Niu Date: Mon, 18 Jul 2022 19:02:18 +0800 Subject: [PATCH] [MLU] fix mlu ctest final. (#44404) --- .../fluid/tests/unittests/mlu/test_c_comm_init_op_mlu.sh | 3 +-- .../fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py | 6 +++--- python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py | 7 ++++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_c_comm_init_op_mlu.sh b/python/paddle/fluid/tests/unittests/mlu/test_c_comm_init_op_mlu.sh index 97f21798c11..36fc85ba6da 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_c_comm_init_op_mlu.sh +++ b/python/paddle/fluid/tests/unittests/mlu/test_c_comm_init_op_mlu.sh @@ -17,5 +17,4 @@ set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). -launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py -MLU_VISIBLE_DEVICES=0,1 python ${launch_py} c_comm_init_op_mlu.py +MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch c_comm_init_op_mlu.py diff --git a/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py index 89475eb6985..1f12d47da42 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_hard_swish_op_mlu.py @@ -16,13 +16,13 @@ from __future__ import print_function import paddle.nn.functional as F import paddle.fluid as fluid import paddle +import sys + +sys.path.append("..") from op_test import OpTest import numpy as np import unittest -import sys - -sys.path.append("..") paddle.enable_static() SEED = 2021 diff --git a/python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py index e52b5ee301c..fc1d62bfdad 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_spawn_mlu.py @@ -102,12 +102,13 @@ class TestSpawn(unittest.TestCase): self.assertEqual(nprocs, core.get_mlu_device_count()) def test_spawn(self): - context = dist.spawn(train, backend='cncl', nprocs=4) + num_devs = core.get_mlu_device_count() + context = dist.spawn(train, backend='cncl', nprocs=num_devs) rank_list = [] - for i in range(4): + for i in range(num_devs): rank_list.append(context.return_queues[i].get()) rank_list.sort() - self.assertEqual(rank_list, list(range(4))) + self.assertEqual(rank_list, list(range(num_devs))) if __name__ == '__main__': -- GitLab