# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import subprocess
import sys
import tempfile
import unittest
from os import listdir
from os.path import isfile, join

# File name under which each test writes its generated training script.
pyname = 'train.py'

# Training script for the collective-mode tests. It asserts that the launcher
# exported the expected PADDLE_* variables into the worker environment; the
# master/rank checks are skipped when PADDLE_AUTO_PARALLEL_CONFIG is present.
colpyfile = '''# train.py for unitest
import os
env = os.environ.copy()
if "PADDLE_AUTO_PARALLEL_CONFIG" not in env:
    assert "PADDLE_MASTER" in env
    assert "PADDLE_GLOBAL_RANK" in env
    assert "PADDLE_LOCAL_RANK" in env
assert "PADDLE_GLOBAL_SIZE" in env
assert "PADDLE_LOCAL_SIZE" in env
'''

# Training script for the parameter-server tests: asserts that the PS-specific
# variables were exported by the launcher.
pspyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_PSERVERS_IP_PORT_LIST" in env
assert "PADDLE_TRAINER_ENDPOINTS" in env
assert "PADDLE_ROLE" in env
#assert "PADDLE_RANK" in env
'''


def write_file(name, ct):
    """Write the text ``ct`` to the file ``name``, replacing any existing content.

    The encoding is fixed to UTF-8 so the result does not depend on the
    platform's locale default encoding.
    """
    with open(name, "w", encoding="utf-8") as f:
        f.write(ct)


def get_files(pth, prefix):
    """List the regular files directly inside ``pth``, minus bookkeeping logs.

    Names ending in ``gpu.log`` or starting with ``envlog`` are dropped.
    Subdirectories are ignored. NOTE: ``prefix`` is accepted for call-site
    readability but is not currently used to filter the result.
    """
    skip_suffix = 'gpu.log'
    skip_prefix = 'envlog'
    kept = []
    for name in listdir(pth):
        if not isfile(join(pth, name)):
            continue
        if name.endswith(skip_suffix) or name.startswith(skip_prefix):
            continue
        kept.append(name)
    return kept


class Collective_Test(unittest.TestCase):
    """End-to-end tests for collective mode of ``paddle.distributed.launch``.

    Each test runs the generated ``colpyfile`` script through the launcher in a
    subprocess. It checks the exit code and, where relevant, how many log files
    ``get_files`` finds in the log directory.
    """

    def setUp(self):
        # Write the training script into a fresh per-test temporary directory.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.path = os.path.join(self.temp_dir.name, pyname)
        write_file(self.path, colpyfile)

    def tearDown(self):
        self.temp_dir.cleanup()

    def _make_temp_dir(self):
        """Create a temporary directory that is removed even if the test fails."""
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        return tmp

    def pdrun(self, args, env=None):
        """Start the launcher on ``self.path`` and return the ``Popen`` handle.

        Args:
            args: Space-separated launcher arguments; may be empty or None.
            env: Optional mapping of extra environment variables. It is merged
                over a copy of ``os.environ`` after the test defaults, so it
                can override them.

        Returns:
            The started ``subprocess.Popen`` object (not yet waited on).
        """
        # Use the running interpreter itself. Its basename alone would be
        # re-resolved through PATH and could start a different Python.
        cmd = [sys.executable, "-m", "paddle.distributed.launch"]
        if args:
            cmd.extend(args.split())
        cmd.append(self.path)
        run_env = os.environ.copy()
        # virtual devices for testing
        run_env['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
        if env:
            run_env.update(env)
        return subprocess.Popen(cmd, env=run_env)

    def test_collective_1(self):
        log_dir = self._make_temp_dir()
        p = self.pdrun(f"--job_id test1 --log_dir {log_dir.name}")
        self.assertEqual(p.wait(), 0)

    def test_collective_2(self):
        log_dir = self._make_temp_dir()
        p = self.pdrun(
            f"--job_id test2 --devices 0,1,2 --log_dir {log_dir.name}"
        )
        self.assertEqual(p.wait(), 0)

        c = get_files(log_dir.name, 'test2')
        self.assertEqual(len(c), 4, c)

    def test_collective_3(self):
        log_dir = self._make_temp_dir()
        port = random.randrange(6000, 8000)
        args = "--job_id test3 --devices 0,1 --log_dir {} --master 127.0.0.1:{} --nnodes 2"
        p1 = self.pdrun(args.format(log_dir.name + "/1", port))
        p2 = self.pdrun(args.format(log_dir.name + "/2", port))
        # Wait on both nodes before asserting so neither process is orphaned.
        rc1 = p1.wait()
        rc2 = p2.wait()
        self.assertEqual(rc1, 0)
        self.assertEqual(rc2, 0)

        c1 = get_files(log_dir.name + "/1", 'test3')
        c2 = get_files(log_dir.name + "/2", 'test3')
        self.assertEqual(len(c1), 3, c1)
        self.assertEqual(len(c2), 3, c2)

    def test_collective_4(self):
        log_dir = self._make_temp_dir()
        config_dir = self._make_temp_dir()
        config_path = os.path.join(config_dir.name, 'auto_parallel_config.json')
        with open(config_path, 'w', encoding='utf-8') as wobj:
            wobj.write(
                '{"tuner_save_path":"parallel_strategy.pkl",'
                '"tuner_load_path":"parallel_strategy.pkl",'
                '"tuner_run_mode":"tuner_and_run"}'
            )
        args = "--job_id test4 --devices 0,1 --log_dir {} --auto_parallel_config {}"
        p1 = self.pdrun(args.format(log_dir.name + "/1", config_path))
        self.assertEqual(p1.wait(), 0)

        c1 = get_files(log_dir.name + "/1", 'test4')
        self.assertEqual(len(c1), 4, c1)


class PS_Test(unittest.TestCase):
    """End-to-end tests for parameter-server mode of ``paddle.distributed.launch``.

    Each test runs the generated ``pspyfile`` script through the launcher in a
    subprocess. It checks the exit code and, where relevant, how many log files
    ``get_files`` finds in the log directory.
    """

    def setUp(self):
        # Write the training script into a fresh per-test temporary directory.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.path = os.path.join(self.temp_dir.name, pyname)
        write_file(self.path, pspyfile)

    def tearDown(self):
        self.temp_dir.cleanup()

    def _make_temp_dir(self):
        """Create a temporary directory that is removed even if the test fails."""
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        return tmp

    def pdrun(self, args, env=None):
        """Start the launcher on ``self.path`` and return the ``Popen`` handle.

        Args:
            args: Space-separated launcher arguments; may be empty or None.
            env: Optional mapping of extra environment variables, merged over
                a copy of ``os.environ``. When omitted, the child simply
                inherits the current environment.

        Returns:
            The started ``subprocess.Popen`` object (not yet waited on).
        """
        # Use the running interpreter itself. Its basename alone would be
        # re-resolved through PATH and could start a different Python.
        cmd = [sys.executable, "-m", "paddle.distributed.launch"]
        if args:
            cmd.extend(args.split())
        cmd.append(self.path)
        run_env = None
        if env:
            run_env = os.environ.copy()
            run_env.update(env)
        # env must be passed by keyword: Popen's second positional parameter is
        # bufsize, so Popen(cmd, env) would never apply the environment.
        return subprocess.Popen(cmd, env=run_env)

    def test_ps_1(self):
        log_dir = self._make_temp_dir()
        p = self.pdrun(f"--run_mode ps --log_dir {log_dir.name}")
        self.assertEqual(p.wait(), 0)

    def test_ps_2(self):
        log_dir = self._make_temp_dir()
        p = self.pdrun(
            f"--job_id ps2 --server_num=2 --trainer_num=2 --log_dir {log_dir.name}"
        )
        self.assertEqual(p.wait(), 0)

        c = get_files(log_dir.name, 'ps2')
        self.assertEqual(len(c), 5, c)

    def test_ps_3(self):
        log_dir = self._make_temp_dir()
        port = random.randrange(6000, 8000)
        args = "--job_id ps3 --log_dir {} --master 127.0.0.1:{} --nnodes 2 --server_num=1 --trainer_num=1"
        p1 = self.pdrun(args.format(log_dir.name + "/1", port))
        p2 = self.pdrun(args.format(log_dir.name + "/2", port))
        # Wait on both nodes before asserting so neither process is orphaned.
        rc1 = p1.wait()
        rc2 = p2.wait()
        self.assertEqual(rc1, 0)
        self.assertEqual(rc2, 0)

        c1 = get_files(log_dir.name + "/1", 'ps3')
        c2 = get_files(log_dir.name + "/2", 'ps3')
        self.assertEqual(len(c1), 3, c1)
        self.assertEqual(len(c2), 3, c2)

    def test_ps_4(self):
        log_dir = self._make_temp_dir()
        # NOTE(review): fixed ports 8900-8903 may collide on a busy CI host.
        args = "--job_id ps4 --log_dir {} --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903".format(
            log_dir.name
        )
        p1 = self.pdrun(args)
        self.assertEqual(p1.wait(), 0)

        c = get_files(log_dir.name, 'ps4')
        self.assertEqual(len(c), 5, c)


if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_run.py`.
    unittest.main()