#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
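
"""Test that the listen_and_serv op starts a parameter server, publishes
its listen port, and shuts down on SIGINT/SIGTERM in both sync and async
modes."""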

import paddle
import paddle.fluid as fluid
import os
import signal
import subprocess
import time
import unittest
from multiprocessing import Process
from op_test import OpTest


def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
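    """Build a one-layer linear regression program, transpile it into a
    parameter-server program, and run it until the process is signalled.
    """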
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None)
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # loss function
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    # optimizer
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    sgd_optimizer.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

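    # single-pserver setup: the full endpoint list and this server's
    # endpoint are the same address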
    pserver_endpoints = ip + ":" + port
    current_endpoint = ip + ":" + port
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        sync_mode=sync_mode)
    pserver_prog = t.get_pserver_program(current_endpoint)
    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
    exe.run(pserver_startup)
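    # blocks here: the listen_and_serv op keeps serving RPC requests
    # until the process receives a signal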
    exe.run(pserver_prog)


class TestListenAndServOp(OpTest):
    def setUp(self):
        self.ps_timeout = 5  # seconds to wait for the pserver to come up
        self.ip = "127.0.0.1"
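        # port "0" lets the pserver pick a free port; the chosen port is
        # written to /tmp/paddle.<pid>.port (polled by _wait_ps_ready)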
        self.port = "0"
        self.trainers = 1
        self.trainer_id = 0

    def _start_pserver(self, use_cuda, sync_mode):
        p = Process(
            target=run_pserver,
            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                  self.trainer_id))
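        # daemonize so the pserver is killed if the test process exits first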
        p.daemon = True
        p.start()
        return p

    def _wait_ps_ready(self, pid):
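        """Poll for the port file the pserver writes when ready, or time out."""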
        start_left_time = self.ps_timeout
        sleep_time = 0.5
        while True:
            assert start_left_time >= 0, "wait ps ready timed out"
            time.sleep(sleep_time)
            try:
                # the listen_and_serv op touches a file under /tmp containing
                # the listen port once it is ready to process RPC calls.
                os.stat("/tmp/paddle.%d.port" % pid)
                return
            except OSError:
                start_left_time -= sleep_time

    def test_rpc_interfaces(self):
        # TODO(Yancey1989): need to make sure the rpc interfaces work correctly.
        pass

    def test_handle_signal_in_serv_op(self):
        # run pserver on CPU in sync mode
        p1 = self._start_pserver(False, True)
        self._wait_ps_ready(p1.pid)

        # send SIGINT to the pserver
        os.kill(p1.pid, signal.SIGINT)
        p1.join()

        # run pserver on CPU in async mode
        p2 = self._start_pserver(False, False)
        self._wait_ps_ready(p2.pid)

        # send SIGTERM to the pserver
        os.kill(p2.pid, signal.SIGTERM)
        p2.join()


if __name__ == '__main__':
    unittest.main()