# test_dist_se_resnext.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import argparse
import time
import math

import unittest
import os
import signal
import subprocess


class TestDistSeResneXt2x2(unittest.TestCase):
    """Compare a single-process run of dist_se_resnext.py with a distributed run.

    The distributed run uses the two pserver endpoints and two trainers
    configured in ``setUp``. Every child process is launched through
    ``subprocess`` with the interpreter named by ``self._python_interp``.
    """

    def setUp(self):
        """Configure trainer/pserver counts, endpoints and the interpreter."""
        self._trainers = 2
        self._pservers = 2
        self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
        self._python_interp = "python"

    def start_pserver(self):
        """Launch one pserver process per configured endpoint.

        Returns:
            A ``(ps0_proc, ps1_proc)`` tuple of ``subprocess.Popen`` objects.
        """
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps0_cmd = "%s dist_se_resnext.py pserver %s 0 %s %d TRUE" % \
            (self._python_interp, self._ps_endpoints, ps0_ep, self._trainers)
        ps1_cmd = "%s dist_se_resnext.py pserver %s 0 %s %d TRUE" % \
            (self._python_interp, self._ps_endpoints, ps1_ep, self._trainers)

        # NOTE(review): both pipes are handed back to the caller unread; a
        # pserver that logs heavily could block once a pipe buffer fills.
        ps0_proc = subprocess.Popen(
            ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        ps1_proc = subprocess.Popen(
            ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return ps0_proc, ps1_proc

    def _wait_ps_ready(self, pid, retry_times=20, interval=3):
        """Block until the port file for pserver ``pid`` shows up in /tmp.

        Args:
            pid: process id of the pserver to wait for.
            retry_times: number of retries after the first failed check.
            interval: seconds to sleep before each check.

        Raises:
            AssertionError: if the file is still missing after all retries.
        """
        # Raised explicitly rather than via `assert`, so the retry limit is
        # still enforced when Python runs with -O.
        for _ in range(retry_times + 1):
            time.sleep(interval)
            try:
                # the listen_and_serv_op would touch a file which contains the
                # listen port on the /tmp directory once it is ready to
                # process all the RPC calls.
                os.stat("/tmp/paddle.%d.port" % pid)
                return
            except OSError:
                continue
        raise AssertionError("wait ps ready failed")

    @staticmethod
    def _required_envs():
        """Return the environment variables shared by every child process.

        Variables that are unset in the parent are omitted, because
        ``subprocess.Popen`` rejects an ``env`` mapping holding ``None``.
        """
        envs = {"FLAGS_fraction_of_gpu_memory_to_use": "0.15"}
        for name in ("PATH", "PYTHONPATH", "LD_LIBRARY_PATH"):
            value = os.getenv(name)
            if value is not None:
                envs[name] = value
        return envs

    @staticmethod
    def _to_text(data):
        """Return child-process output as ``str`` (decodes bytes on Python 3)."""
        if isinstance(data, str):
            return data
        return data.decode("utf-8")

    @staticmethod
    def _parse_losses(output, space_separated=False):
        """Return ``(first, last)``: element 0 of the first two output lines.

        Args:
            output: text printed by a trainer, one list-like value per line.
            space_separated: when True, spaces are turned into commas before
                the line is evaluated.

        Raises:
            ValueError: if ``output`` has fewer than two lines.
        """
        lines = output.split("\n")
        if len(lines) < 2:
            raise ValueError(
                "expected at least two lines of loss output, got %r" % output)
        losses = []
        for line in lines[:2]:
            if space_separated:
                line = line.replace(" ", ",")
            # NOTE(review): eval() runs whatever the child printed. The input
            # is produced by our own script, but ast.literal_eval would be
            # the safer parser if the output format allows it.
            losses.append(eval(line)[0])
        return losses[0], losses[1]

    @staticmethod
    def _kill_process(proc):
        """Kill ``proc`` if it is still running, reap it and close its pipes."""
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        for stream in (proc.stdout, proc.stderr):
            if stream is not None:
                stream.close()

    def non_test_with_place(self):
        """Check that distributed training reproduces the local losses.

        The name does not start with ``test``, so the default unittest
        loader does not pick it up.
        """
        # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
        required_envs = self._required_envs()

        # Run local to get a base line
        env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        env_local.update(required_envs)
        # NOTE(review): "FLASE" looks like a typo for "FALSE"; it is kept
        # as-is because how dist_se_resnext.py reads it is not visible here.
        local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
            (self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
        local_proc = subprocess.Popen(
            local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local)
        # communicate() drains the pipe while waiting; wait() followed by
        # read() can deadlock once the child fills the pipe buffer.
        local_out, _ = local_proc.communicate()
        local_ret = self._to_text(local_out)

        # Run dist train to compare with local results
        ps0, ps1 = self.start_pserver()
        children = [ps0, ps1]
        try:
            self._wait_ps_ready(ps0.pid)
            self._wait_ps_ready(ps1.pid)

            ps0_ep, ps1_ep = self._ps_endpoints.split(",")
            tr0_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d TRUE" % \
                (self._python_interp, self._ps_endpoints, ps0_ep,
                 self._trainers)
            tr1_cmd = "%s dist_se_resnext.py trainer %s 1 %s %d TRUE" % \
                (self._python_interp, self._ps_endpoints, ps1_ep,
                 self._trainers)

            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
            env0.update(required_envs)
            env1.update(required_envs)

            with open(os.devnull, 'w') as fnull:
                tr0_proc = subprocess.Popen(
                    tr0_cmd.split(" "),
                    stdout=subprocess.PIPE,
                    stderr=fnull,
                    env=env0)
                children.append(tr0_proc)
                # Only trainer 0's output is compared, so trainer 1 writes to
                # devnull instead of a pipe that nobody would ever read.
                tr1_proc = subprocess.Popen(
                    tr1_cmd.split(" "), stdout=fnull, stderr=fnull, env=env1)
                children.append(tr1_proc)

                tr0_out, _ = tr0_proc.communicate()
                tr1_proc.wait()

            dist_first_loss, dist_last_loss = self._parse_losses(
                self._to_text(tr0_out), space_separated=True)
            local_first_loss, local_last_loss = self._parse_losses(local_ret)

            self.assertAlmostEqual(local_first_loss, dist_first_loss)
            self.assertAlmostEqual(local_last_loss, dist_last_loss)
        finally:
            # Runs even when an assertion above fails, so no pserver or
            # trainer process is left behind.
            for child in children:
                self._kill_process(child)


# Run this module's unittest cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()