From a6747a6ef1e06d0b90389a7b5e4a5df3f9454d89 Mon Sep 17 00:00:00 2001
From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com>
Date: Sat, 2 Nov 2019 12:03:26 +0800
Subject: [PATCH] =?UTF-8?q?add=20launch=5Fps=20module=20so=20that=20we=20c?=
 =?UTF-8?q?an=20launch=20a=20parameter=20server=20trainin=E2=80=A6=20(#209?=
 =?UTF-8?q?36)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add launch_ps module so that we can launch a parameter server training job

1) a user can specify worker_num and server_num
2) parameter server can be killed after all workers exit
3) unit test is added

test=develop
---
 python/paddle/distributed/launch_ps.py        | 38 +++++++--------
 .../fluid/tests/unittests/CMakeLists.txt      |  2 +
 .../tests/unittests/fleet_ps_training.py      | 46 +++++++++++++++++++
 .../fluid/tests/unittests/test_launch_ps.sh   | 11 +++++
 4 files changed, 78 insertions(+), 19 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/fleet_ps_training.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_launch_ps.sh

diff --git a/python/paddle/distributed/launch_ps.py b/python/paddle/distributed/launch_ps.py
index 4a6885c888..f8489965e7 100644
--- a/python/paddle/distributed/launch_ps.py
+++ b/python/paddle/distributed/launch_ps.py
@@ -100,18 +100,16 @@ def start_procs(args):
     user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
     for i in range(server_num):
         current_env.update({
-            "PADDLE_TRAINERS_NUM": str(server_num),
-            "PADDLE_PORT": ",".join(user_endpoints_port),
-            #"POD_IP": user_endpoints_ips[i],
-            "CURRENT_ENDPOINT":
-            user_endpoints_ips[i] + ":" + user_endpoints_port[i],
-            "PADDLE_PSERVERS": ",".join(user_endpoints_ips),
-            "PADDLE_TRAINING_ROLE": "PSERVER"
+            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
+            "PADDLE_PORT": user_endpoints_port[i],
+            "TRAINING_ROLE": "PSERVER",
+            "PADDLE_TRAINERS_NUM": str(worker_num),
+            "POD_IP": user_endpoints_ips[i]
         })
+
         cmd = [sys.executable, "-u", args.training_script
                ] + args.training_script_args
         cmds.append(cmd)
-        print(cmd)
         if args.log_dir is not None:
             os.system("mkdir -p {}".format(args.log_dir))
             fn = open("%s/serverlog.%d" % (args.log_dir, i), "w")
@@ -123,15 +121,13 @@ def start_procs(args):
 
     for i in range(worker_num):
         current_env.update({
-            "PADDLE_PSERVERS": ",".join(user_endpoints_ips),
-            "PADDLE_PORT": ",".join(user_endpoints_port),
+            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
             "PADDLE_TRAINERS_NUM": str(worker_num),
-            "PADDLE_TRAINING_ROLE": "TRAINER",
+            "TRAINING_ROLE": "TRAINER",
             "PADDLE_TRAINER_ID": str(i)
         })
         cmd = [sys.executable, "-u", args.training_script
                ] + args.training_script_args
-        print(cmd)
         cmds.append(cmd)
         if args.log_dir is not None:
             os.system("mkdir -p {}".format(args.log_dir))
@@ -142,16 +138,20 @@ def start_procs(args):
             proc = subprocess.Popen(cmd, env=current_env)
         procs.append(proc)
 
-    for i in range(0, len(procs)):
-        proc = procs[i]
-
-        proc.wait()
+    # only wait worker to finish here
+    for i, proc in enumerate(procs):
+        if i < server_num:
+            continue
+        procs[i].wait()
         if len(log_fns) > 0:
             log_fns[i].close()
 
-        if proc.returncode != 0:
-            raise subprocess.CalledProcessError(
-                returncode=procs[i].returncode, cmd=cmds[i])
+    print("all workers exit, going to finish parameter server", file=sys.stderr)
+    for i in range(server_num):
+        if len(log_fns) > 0:
+            log_fns[i].close()
+        procs[i].terminate()
+    print("all parameter server are killed", file=sys.stderr)
 
 
 def launch():
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 10735ab030..374cd136ce 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -18,6 +18,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_nce_remote_table_op)
 list(APPEND MIXED_DIST_TEST_OPS test_hsigmoid_remote_table_op)
 list(APPEND MIXED_DIST_TEST_OPS test_lookup_remote_table_op)
 list(APPEND MIXED_DIST_TEST_OPS test_launch)
+list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
 foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
   list(REMOVE_ITEM TEST_OPS ${TEST_OP})
 endforeach()
@@ -259,6 +260,7 @@ if(WITH_DISTRIBUTE)
     if(NOT APPLE)
         bash_test_modules(test_listen_and_serv_op MODULES test_listen_and_serv.sh)
         bash_test_modules(test_launch MODULES test_launch.sh)
+        bash_test_modules(test_launch_ps MODULES test_launch_ps.sh)
 
         set(dist_ut_port 1000)
         foreach(TEST_OP ${DIST_TEST_OPS})
diff --git a/python/paddle/fluid/tests/unittests/fleet_ps_training.py b/python/paddle/fluid/tests/unittests/fleet_ps_training.py
new file mode 100644
index 0000000000..a9e9675a61
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/fleet_ps_training.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+from utils import gen_data
+from nets import mlp
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
+from paddle.fluid.incubate.fleet.base import role_maker
+
+input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
+input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
+
+cost = mlp(input_x, input_y)
+optimizer = fluid.optimizer.Adagrad(learning_rate=0.01)
+
+role = role_maker.PaddleCloudRoleMaker()
+fleet.init(role)
+
+optimizer = fleet.distributed_optimizer(optimizer)
+optimizer.minimize(cost)
+
+if fleet.is_server():
+    fleet.init_server()
+    fleet.run_server()
+elif fleet.is_worker():
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fleet.startup_program)
+    step = 1001
+    for i in range(step):
+        cost_val = exe.run(program=fleet.main_program,
+                           feed=gen_data(),
+                           fetch_list=[cost.name])
+        print("worker_index: %d, step%d cost = %f" %
+              (fleet.worker_index(), i, cost_val[0]))
diff --git a/python/paddle/fluid/tests/unittests/test_launch_ps.sh b/python/paddle/fluid/tests/unittests/test_launch_ps.sh
new file mode 100644
index 0000000000..0bd722af03
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_launch_ps.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+# use default values
+python -m paddle.distributed.launch_ps fleet_ps_training.py 2> ut.elog
+
+if grep -q "server are killed" ut.elog; then
+    echo "succeed"
+else
+    echo "failed"
+    exit -1
+fi
-- 
GitLab
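Usage sketch: test_launch_ps.sh above only exercises the launcher with its
default worker and server counts. Per the commit message, worker_num and
server_num can be set explicitly; the invocation below is a hedged example,
and the flag spellings (--worker_num, --server_num, --log_dir) are assumed
from the launcher's argument names, which are not shown in this diff.

    # flag names are assumed; only the default invocation appears in the test
    python -m paddle.distributed.launch_ps \
        --worker_num 2 --server_num 2 --log_dir ps_log \
        fleet_ps_training.py

With either invocation, launch_ps starts the PSERVER processes, then the
TRAINER processes, waits only on the trainers, and terminates the parameter
servers once every trainer has exited, printing "all parameter server are
killed" to stderr, which is the message the unit test greps for.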