#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    High-level unit test utilities for distributed fleet.
"""

from __future__ import print_function

import argparse
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import unittest
from contextlib import closing

import numpy as np
import six

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid as fluid
from paddle.distributed.fleet.utils.ps_util import DistributedInfer
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

paddle.enable_static()

__all__ = ['FleetDistRunnerBase', 'TestFleetBase', 'runtime_main']
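# FleetDistRunnerBase: the per-process runner that runtime_main drives inside
#     each spawned trainer/pserver subprocess.
# TestFleetBase: the unittest harness that launches two pservers and two
#     trainers as subprocesses and checks their exit codes.
# runtime_main: the command-line entry point executed in each subprocess.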

RUN_STEP = 5
LEARNING_RATE = 0.01
DIST_UT_PORT = 0
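# When PADDLE_DIST_UT_PORT pins a starting port, each test consumes four
# consecutive ports (two pservers, two trainers); while it stays 0, free
# ports are picked at runtime in TestFleetBase.setUp.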


class FleetDistRunnerBase(object):
    """
        run_pserver, run_trainer: after the role is initialized, split the
            program with the transpiler and run it
        net: implemented by child classes, builds the network of the model
        do training: the executor runs the program
    """

    def __init__(self):
        self._exe = None

    def build_role(self, args):
        if args.role.upper() == "PSERVER":
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.SERVER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","))
        else:
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.WORKER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","))
        self.role = role
        return role

    def build_strategy(self, args):
        if args.mode == "sync":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = False
        elif args.mode == "async":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
        elif args.mode == "geo":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
            self.strategy.a_sync_configs = {
                "k_steps": args.geo_sgd_need_push_nums
            }
        elif args.mode == "auto":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.auto = True

        self.dump_param = os.getenv("dump_param", "").split(",")
        self.dump_fields = os.getenv("dump_fields", "").split(",")
        self.dump_fields_path = os.getenv("dump_fields_path", "")
        debug = int(os.getenv("Debug", "0"))
        # TODO(update strategy to support dump params)
        if False:  # debug:
            self.strategy.set_debug_opt({
                "dump_param": self.dump_param,
                "dump_fields": self.dump_fields,
                "dump_fields_path": self.dump_fields_path
            })

        return self.strategy

    def build_optimizer(self, avg_cost, strategy):
        use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
        grad_clip = None
        if use_grad_clip:
            # 1: clip_by_value; 2: clip_by_norm; 3: clip_by_global_norm
            if use_grad_clip == 1:
                grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
            elif use_grad_clip == 2:
                grad_clip = paddle.nn.ClipGradByNorm(2.0)
            elif use_grad_clip == 3:
                grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)

        use_decay = int(os.getenv("USE_DECAY", "0"))
        if use_decay:
            scheduler = paddle.optimizer.lr.ExponentialDecay(
                learning_rate=LEARNING_RATE, gamma=0.999, verbose=True)
            optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
            """
            # learning rate decay method before 2.0
            optimizer = fluid.optimizer.SGD(
                learning_rate=fluid.layers.exponential_decay(
                    learning_rate=LEARNING_RATE,
                    decay_steps=500,
                    decay_rate=0.969,
                    staircase=True))
            """
        else:
            optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

    def run_pserver(self, args):
        fleet.init_server()
        fleet.run_server()

    def run_dataset_trainer(self, args):
        out = self.do_dataset_training(fleet)

    def run_pyreader_trainer(self, args):
        out = self.do_pyreader_training(fleet)

    def net(self, args, batch_size=4, lr=0.01):
        raise NotImplementedError(
            "net should be implemented by child classes.")

    def get_executor(self):
        if self._exe is None:
            device_env = os.getenv("DEVICE", 'cpu')
            if device_env == 'cpu':
                device = fluid.CPUPlace()
            elif device_env == 'gpu':
                device = fluid.CUDAPlace(0)
            self._exe = fluid.Executor(device)
        return self._exe

    def do_dataset_training(self, fleet):
        raise NotImplementedError(
            "do_dataset_training should be implemented by child classes.")

    def do_pyreader_training(self, fleet):
        raise NotImplementedError(
            "do_pyreader_training should be implemented by child classes.")

    def do_distributed_testing(self, fleet):
        raise NotImplementedError(
            "do_distributed_testing should be implemented by child classes.")


class TestFleetBase(unittest.TestCase):
    """
        start_pserver, start_trainer: build the start command and run it in
            subprocesses for the test
        run_cluster: test the distributed program with multiple processes
    """

    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def tearDown(self):
        t = time.time() - self.startTime
        print('%s: %.3f' % (self.__class__.__name__, t))

    def setUp(self):
        self.startTime = time.time()

        self._mode = "sync"
        self._reader = "pyreader"
        self._trainers = 2
        self._pservers = 2
        self._need_test = 0
        self._model_dir = ""
        self._port_set = set()

        global DIST_UT_PORT
        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))

        if DIST_UT_PORT:
            print("set begin_port:", DIST_UT_PORT)
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT, DIST_UT_PORT + 1)
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT + 2, DIST_UT_PORT + 3)
            DIST_UT_PORT += 4
        else:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(), self._find_free_port())
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(), self._find_free_port())

        self._python_interp = sys.executable
        self._geo_sgd_need_push_nums = 5
        self._grad_clip_mode = 0
        self._setup_config()

    def _find_free_port(self):

        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _start_pserver(self, cmd, required_envs):
        ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        ps0_err_log = os.path.join(log_dirname, log_prename + "ps0_stderr.log")
        ps1_err_log = os.path.join(log_dirname, log_prename + "ps1_stderr.log")
        ps0_out_log = os.path.join(log_dirname, log_prename + "ps0_stdout.log")
        ps1_out_log = os.path.join(log_dirname, log_prename + "ps1_stdout.log")

        ps0_err = open(ps0_err_log, "wb+")
        ps1_err = open(ps1_err_log, "wb+")

        ps0_out = open(ps0_out_log, "wb+")
        ps1_out = open(ps1_out_log, "wb+")

        ps0_proc = subprocess.Popen(ps0_cmd.strip().split(" "),
                                    stdout=ps0_out,
                                    stderr=ps0_err,
                                    env=required_envs)

        ps1_proc = subprocess.Popen(ps1_cmd.strip().split(" "),
                                    stdout=ps1_out,
                                    stderr=ps1_err,
                                    env=required_envs)

        return ((ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log),
                (ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log))

    def _start_trainer(self, cmd, required_envs):
        tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        tr0_err_log = os.path.join(log_dirname, log_prename + "tr0_stderr.log")
        tr1_err_log = os.path.join(log_dirname, log_prename + "tr1_stderr.log")
        tr0_out_log = os.path.join(log_dirname, log_prename + "tr0_stdout.log")
        tr1_out_log = os.path.join(log_dirname, log_prename + "tr1_stdout.log")

        tr0_err = open(tr0_err_log, "wb+")
        tr1_err = open(tr1_err_log, "wb+")

        tr0_out = open(tr0_out_log, "wb+")
        tr1_out = open(tr1_out_log, "wb+")

        tr0_proc = subprocess.Popen(tr0_cmd.strip().split(" "),
                                    stdout=tr0_out,
                                    stderr=tr0_err,
                                    env=required_envs)

        tr1_proc = subprocess.Popen(tr1_cmd.strip().split(" "),
                                    stdout=tr1_out,
                                    stderr=tr1_err,
                                    env=required_envs)

        return ((tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log),
                (tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log))

    def _run_cluster(self, model, envs):
        env = {'GRAD_CLIP': str(self._grad_clip_mode), 'WITH_DISTRIBUTE': 'ON'}
        python_path = self._python_interp
        gloo_path = tempfile.mkdtemp()

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            python_path += " -m coverage run --branch -p"
        env.update(envs)

        tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path, model, self._ps_endpoints, self._tr_endpoints,
            self._trainers, self._mode, self._geo_sgd_need_push_nums,
            self._reader, gloo_path, self._need_test)

        ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path, model, self._ps_endpoints, self._tr_endpoints,
            self._trainers, self._mode, self._geo_sgd_need_push_nums,
            self._reader, gloo_path, self._need_test)

        if self._model_dir:
            tr_cmd += " --model_dir {}".format(self._model_dir)
            ps_cmd += " --model_dir {}".format(self._model_dir)

        # Run dist train to compare with local results
        ps0, ps1 = self._start_pserver(ps_cmd, env)
        tr0, tr1 = self._start_trainer(tr_cmd, env)

        ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log = ps0
        ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log = ps1

        tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log = tr0
        tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log = tr1

        # Wait until the trainer processes terminate
        # time_out = 120
        time_out = 60
        cur_time = 0

        while True:
            stat0 = tr0_proc.poll()
            stat1 = tr1_proc.poll()

            if stat0 is not None and stat1 is not None:
                break
            else:
                time.sleep(0.5)
                cur_time += 0.5

            if cur_time >= time_out:
                tr0_proc.terminate()
                tr1_proc.terminate()
                tr0_proc.wait()
                tr1_proc.wait()
                break

        tr0_ret = tr0_proc.returncode
        tr1_ret = tr1_proc.returncode

        ps0_proc.kill()
        ps1_proc.kill()
        ps0_proc.wait()
        ps1_proc.wait()

        def is_listen_failed(logx):
            is_lf = False

            listen_rgx = "Fail to listen"

            with open(logx, "r") as rb:
                for line in rb.readlines():
                    if listen_rgx in line:
                        is_lf = True
                        break
            return is_lf

        def catlog(logx):
            basename = os.path.basename(logx)
            print("\n================== Error {} begin =====================".
                  format(basename))
            os.system("cat {}".format(logx))
            print("================== Error {} end =====================\n".
                  format(basename))

        if tr0_ret != 0 or tr1_ret != 0:
            if is_listen_failed(ps0_err_log) or is_listen_failed(ps1_err_log):
                print("parameter server port bind failed, skipping the error")
                tr0_ret, tr1_ret = 0, 0
            else:
                for out, err in [(ps0_out_log, ps0_err_log),
                                 (ps1_out_log, ps1_err_log),
                                 (tr0_out_log, tr0_err_log),
                                 (tr1_out_log, tr1_err_log)]:
                    catlog(out)
                    catlog(err)

        for pipe in [
                tr0_err, tr0_out, tr1_err, tr1_out, ps0_err, ps0_out, ps1_err,
                ps1_out
        ]:
            pipe.close()

414
        shutil.rmtree(gloo_path)

        self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
        self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")

        return 0, 0

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": ""
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)


def runtime_main(test_class):
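    """Entry point executed inside each subprocess spawned by TestFleetBase.

    A model script defines a FleetDistRunnerBase subclass and hands it over,
    for example (a sketch with a hypothetical class name)::

        if __name__ == "__main__":
            runtime_main(DemoRunner)
    """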
    parser = argparse.ArgumentParser(description='Run Fleet test.')
    parser.add_argument('--role',
                        type=str,
                        required=True,
                        choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument('--trainer_endpoints',
                        type=str,
                        required=False,
                        default="")
    parser.add_argument('--gloo_path', type=str, required=False, default="")
    parser.add_argument('--current_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--mode', type=str, required=False, default='geo')
    parser.add_argument('--geo_sgd_need_push_nums',
                        type=int,
                        required=False,
                        default=2)
    parser.add_argument('--reader', type=str, required=False, default='dataset')
    parser.add_argument('--test', type=int, required=False, default=0)
    parser.add_argument('--model_dir', type=str, required=False, default="")
    args = parser.parse_args()

    model = test_class()
    role = model.build_role(args)

    # for distributed inference
    if args.test and args.model_dir != "":
        avg_cost = model.net(args, is_train=False)
        dist_infer = DistributedInfer()
        dist_infer.init_distributed_infer_env(exe=model.get_executor(),
                                              loss=model.avg_cost,
                                              role_maker=role,
                                              dirname=args.model_dir)

        if fleet.is_worker():
            with paddle.static.program_guard(
                    main_program=dist_infer.get_dist_infer_program()):
                model.do_distributed_testing(fleet)
                fleet.stop_worker()
            return

        if fleet.is_server():
            return

    fleet.init(role)
    strategy = model.build_strategy(args)
    avg_cost = model.net(args)
    model.build_optimizer(avg_cost, strategy)

    if args.role == "pserver":
        model.run_pserver(args)
    else:
        if args.reader == "dataset":
            model.run_dataset_trainer(args)
        else:
            model.run_pyreader_trainer(args)

        if args.test:
            test_origin_program = paddle.static.Program()
            test_startup_program = paddle.static.Program()
            with paddle.static.program_guard(
                    main_program=test_origin_program,
                    startup_program=test_startup_program):
                with paddle.utils.unique_name.guard():
                    avg_cost = model.net(args, is_train=False)
            dist_infer = DistributedInfer(main_program=test_origin_program,
                                          startup_program=test_startup_program)
            with paddle.static.program_guard(
                    main_program=dist_infer.get_dist_infer_program()):
                model.do_distributed_testing(fleet)
        fleet.stop_worker()