# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import shutil
import subprocess
import sys
import tempfile
import unittest


class CommunicationTestDistBase(unittest.TestCase):
    """Base test case that runs a script through ``paddle.distributed.launch``.

    ``setUp`` records the interpreter, the device list and a temporary log
    directory; ``run_test_case`` launches the script in a subprocess;
    ``tearDown`` optionally copies the logs to ``save_log_dir`` and then
    removes the temporary log directory.
    """

    def setUp(self, save_log_dir=None, num_of_devices=2, timeout=120):
        """Prepare the launcher settings used by ``run_test_case``.

        Args:
            save_log_dir: Directory into which ``tearDown`` copies the run
                logs. Nothing is copied when it is ``None`` or empty.
            num_of_devices: Number of devices; the device ids are
                ``0 .. num_of_devices - 1``.
            timeout: Seconds the launched command may run before
                ``run_test_case`` raises ``TimeoutError``.
        """
        self._python_interp = sys.executable
        self._save_log_dir = save_log_dir
        self._log_dir = tempfile.TemporaryDirectory()
        self._num_of_devices = num_of_devices
        self._device_list = [str(i) for i in range(num_of_devices)]
        self._timeout = timeout
        self._seeds = [i + 10 for i in range(num_of_devices)]
        self._devices = ','.join(self._device_list)

    def run_test_case(self, script_file, user_defined_envs=None):
        """Launch ``script_file`` with ``paddle.distributed.launch``.

        Args:
            script_file: Script to launch. It is split on whitespace, so
                extra command-line arguments may follow the script path.
            user_defined_envs: Optional mapping of extra environment
                variables for the child process.

        Raises:
            TimeoutError: The command ran longer than the configured timeout.
            RuntimeError: The command exited with a non-zero return code.
        """
        # Work on a copy: updating os.environ itself would leak the
        # user-defined variables and CUDA_VISIBLE_DEVICES into this process
        # and into every test that runs afterwards.
        runtime_envs = os.environ.copy()
        if user_defined_envs is not None:
            runtime_envs.update(user_defined_envs)
        runtime_envs["CUDA_VISIBLE_DEVICES"] = self._devices

        # Build argv directly so an interpreter path or log directory that
        # contains spaces is passed as a single argument.
        start_command_list = [
            self._python_interp,
            "-u",
            "-m",
            "paddle.distributed.launch",
            "--log_dir",
            self._log_dir.name,
            "--devices",
            self._devices,
            *str(script_file).split(),
        ]

        try:
            self._launcher = subprocess.run(
                start_command_list,
                env=runtime_envs,
                timeout=self._timeout,
                check=True,
            )
        except subprocess.TimeoutExpired as err:
            raise TimeoutError(
                "Timeout while running command {}, try to set a longer period, {} is not enough.".format(
                    err.cmd, err.timeout
                )
            ) from err
        except subprocess.CalledProcessError as err:
            raise RuntimeError(
                "Error occurs when running this test case. The return code of command {} is {}".format(
                    err.cmd, err.returncode
                )
            ) from err

    def tearDown(self):
        """Save the logs if requested, then remove the temporary log directory.

        Raises:
            RuntimeError: The destination directory under ``save_log_dir``
                already exists.
        """
        try:
            if self._save_log_dir:
                temp_log_dir_name = os.path.basename(self._log_dir.name)
                dir_name = os.path.join(self._save_log_dir, temp_log_dir_name)
                if os.path.isdir(dir_name):
                    raise RuntimeError(
                        f"Directory {dir_name} exists, failed to save log."
                    )
                print(f"The running logs will copy to {dir_name}")
                shutil.copytree(self._log_dir.name, dir_name)
        finally:
            # Always release the temporary directory, even when saving fails.
            self._log_dir.cleanup()


def gen_product_envs_list(default_envs, changeable_envs):
    """Build one environment dict per combination of changeable values.

    Args:
        default_envs: Mapping of variables added to every generated dict.
            On a key collision its value replaces the changeable one.
        changeable_envs: Mapping from variable name to an iterable of
            candidate values; their Cartesian product is enumerated.

    Returns:
        A list of new dicts, in ``itertools.product`` order. An empty
        ``changeable_envs`` yields a single dict equal to ``default_envs``.
    """
    names = list(changeable_envs)
    return [
        {**dict(zip(names, values)), **default_envs}
        for values in itertools.product(*changeable_envs.values())
    ]