#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
High-level unit tests for distributed fleet.
"""

import argparse
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import unittest
from contextlib import closing

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid as fluid
from paddle.distributed.fleet.utils.ps_util import DistributedInfer

paddle.enable_static()

__all__ = ['FleetDistRunnerBase', 'TestFleetBase', 'runtime_main']

RUN_STEP = 5
LEARNING_RATE = 0.01
DIST_UT_PORT = 0


class FleetDistRunnerBase:
    """
    run_pserver, run_trainer : after the role is initialized, use the
        transpiler to split the program
    net : implemented by the child class; defines the network of the model
    do training : the executor runs the program
    """

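    # A typical child runner (a hypothetical sketch; class names are
    # illustrative) overrides net() and the do_*_training hooks, and is
    # launched through runtime_main() at the bottom of this file:
    #
    #   class MyFleetRunner(FleetDistRunnerBase):
    #       def net(self, args, batch_size=4, lr=0.01):
    #           ...  # build the model and return avg_cost
    #
    #   if __name__ == "__main__":
    #       runtime_main(MyFleetRunner)
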
    def __init__(self):
        self._exe = None

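    # Build a UserDefinedRoleMaker from the parsed CLI arguments. Both
    # branches pass identical endpoint lists; only the Role (SERVER vs.
    # WORKER) and current_id distinguish a pserver from a trainer.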
57
    def build_role(self, args):
58

59 60
        if args.role.upper() == "PSERVER":
            role = role_maker.UserDefinedRoleMaker(
61
                is_collective=False,
62
                init_gloo=False,
63
                path=args.gloo_path,
64 65
                current_id=args.current_id,
                role=role_maker.Role.SERVER,
66
                worker_endpoints=args.trainer_endpoints.split(","),
67 68
                server_endpoints=args.endpoints.split(","),
            )
69 70
        else:
            role = role_maker.UserDefinedRoleMaker(
71
                is_collective=False,
72
                init_gloo=False,
73
                path=args.gloo_path,
74 75
                current_id=args.current_id,
                role=role_maker.Role.WORKER,
76
                worker_endpoints=args.trainer_endpoints.split(","),
77 78
                server_endpoints=args.endpoints.split(","),
            )
79
        self.role = role
80 81 82
        return role

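    # Map the --mode flag onto a DistributedStrategy: "sync" -> a_sync=False,
    # "async" -> a_sync=True, "geo" -> a_sync with k_steps batching, and
    # "auto" -> strategy.auto = True.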
    def build_strategy(self, args):
        if args.mode == "sync":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = False
        elif args.mode == "async":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
        elif args.mode == "geo":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
            self.strategy.a_sync_configs = {
                "k_steps": args.geo_sgd_need_push_nums
            }
        elif args.mode == "auto":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.auto = True

        self.dump_param = os.getenv("dump_param", "").split(",")
        self.dump_fields = os.getenv("dump_fields", "").split(",")
        self.dump_fields_path = os.getenv("dump_fields_path", "")
        debug = int(os.getenv("Debug", "0"))
        # TODO(update strategy to support dump params)
        if False:  # debug:
            self.strategy.set_debug_opt(
                {
                    "dump_param": self.dump_param,
                    "dump_fields": self.dump_fields,
                    "dump_fields_path": self.dump_fields_path,
                }
            )

        return self.strategy

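    # Build a plain SGD optimizer, optionally with gradient clipping (the
    # GRAD_CLIP env var) and exponential LR decay (the USE_DECAY env var),
    # then wrap it with fleet.distributed_optimizer so that minimize()
    # rewrites the program for distributed execution.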
    def build_optimizer(self, avg_cost, strategy):
        use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
        grad_clip = None
        if use_grad_clip:
            # 1: clip_by_value; 2: clip_by_norm; 3: clip_by_global_norm
            if use_grad_clip == 1:
                grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
            elif use_grad_clip == 2:
                grad_clip = paddle.nn.ClipGradByNorm(2.0)
            elif use_grad_clip == 3:
                grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)

        use_decay = int(os.getenv("USE_DECAY", "0"))
        if use_decay:
            scheduler = paddle.optimizer.lr.ExponentialDecay(
                learning_rate=LEARNING_RATE, gamma=0.999, verbose=True
            )
            optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
            """
            # learning rate decay method before 2.0
            optimizer = fluid.optimizer.SGD(
                learning_rate=fluid.layers.exponential_decay(
                    learning_rate=LEARNING_RATE,
                    decay_steps=500,
                    decay_rate=0.969,
                    staircase=True))
            """
        else:
            optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

    def run_pserver(self, args):
        fleet.init_server()
        fleet.run_server()

    def run_dataset_trainer(self, args):
        out = self.do_dataset_training(fleet)

    def run_pyreader_trainer(self, args):
        out = self.do_pyreader_training(fleet)

    def net(self, args, batch_size=4, lr=0.01):
        raise NotImplementedError(
            "net should be implemented by child classes."
        )

    def get_executor(self):
        if self._exe is None:
            device_env = os.getenv("DEVICE", 'cpu')
            if device_env == 'cpu':
                device = fluid.CPUPlace()
            elif device_env == 'gpu':
                device = fluid.CUDAPlace(0)
            else:
                raise ValueError(
                    "DEVICE must be 'cpu' or 'gpu', but got %s" % device_env
                )
            self._exe = fluid.Executor(device)
        return self._exe

    def do_dataset_training(self, fleet):
        raise NotImplementedError(
            "do_dataset_training should be implemented by child classes."
        )

    def do_pyreader_training(self, fleet):
        raise NotImplementedError(
            "do_pyreader_training should be implemented by child classes."
        )

    def do_distributed_testing(self, fleet):
        raise NotImplementedError(
            "do_distributed_testing should be implemented by child classes."
        )


class TestFleetBase(unittest.TestCase):
    """
    start_pserver, start_trainer : build the launch commands for the test
    run_cluster : run the distributed program with multiple processes
    """

    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def tearDown(self):
        t = time.time() - self.startTime
        print('%s: %.3f' % (self.__class__.__name__, t))

    def setUp(self):
        self.startTime = time.time()

        self._mode = "sync"
        self._reader = "pyreader"
        self._trainers = 2
        self._pservers = 2
        self._need_test = 0
        self._model_dir = ""
        self._port_set = set()

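        # Allocate four consecutive ports (two pservers, two trainers) from
        # the PADDLE_DIST_UT_PORT base when it is set; otherwise pick free
        # ports at random.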
        global DIST_UT_PORT
        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))

        if DIST_UT_PORT:
            print("set begin_port:", DIST_UT_PORT)
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT,
                DIST_UT_PORT + 1,
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT + 2,
                DIST_UT_PORT + 3,
            )
            DIST_UT_PORT += 4
        else:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )

        self._python_interp = sys.executable
        self._geo_sgd_need_push_nums = 5
        self._grad_clip_mode = 0
        self._setup_config()

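    # Ask the OS for an ephemeral port by binding to port 0, retrying until
    # we get a port this test instance has not handed out before.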
    def _find_free_port(self):
        def __free_port():
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

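    # The cmd template keeps one "{}" placeholder for current_id; format it
    # with 0 and 1 to build the two pserver commands, then launch both
    # processes with stdout/stderr redirected to per-process log files.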
    def _start_pserver(self, cmd, required_envs):
        ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        ps0_err_log = os.path.join(log_dirname, log_prename + "ps0_stderr.log")
        ps1_err_log = os.path.join(log_dirname, log_prename + "ps1_stderr.log")
        ps0_out_log = os.path.join(log_dirname, log_prename + "ps0_stdout.log")
        ps1_out_log = os.path.join(log_dirname, log_prename + "ps1_stdout.log")

        ps0_err = open(ps0_err_log, "wb+")
        ps1_err = open(ps1_err_log, "wb+")

        ps0_out = open(ps0_out_log, "wb+")
        ps1_out = open(ps1_out_log, "wb+")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=ps0_out,
            stderr=ps0_err,
            env=required_envs,
        )

        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=ps1_out,
            stderr=ps1_err,
            env=required_envs,
        )

        return (
            (ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log),
            (ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log),
        )

    def _start_trainer(self, cmd, required_envs):
        tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        tr0_err_log = os.path.join(log_dirname, log_prename + "tr0_stderr.log")
        tr1_err_log = os.path.join(log_dirname, log_prename + "tr1_stderr.log")
        tr0_out_log = os.path.join(log_dirname, log_prename + "tr0_stdout.log")
        tr1_out_log = os.path.join(log_dirname, log_prename + "tr1_stdout.log")

        tr0_err = open(tr0_err_log, "wb+")
        tr1_err = open(tr1_err_log, "wb+")

        tr0_out = open(tr0_out_log, "wb+")
        tr1_out = open(tr1_out_log, "wb+")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=tr0_out,
            stderr=tr0_err,
            env=required_envs,
        )

        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=tr1_out,
            stderr=tr1_err,
            env=required_envs,
        )

        return (
            (tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log),
            (tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log),
        )

    def _run_cluster(self, model, envs):
        env = {'GRAD_CLIP': str(self._grad_clip_mode), 'WITH_DISTRIBUTE': 'ON'}
        python_path = self._python_interp
        gloo_path = tempfile.mkdtemp()

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            python_path += " -m coverage run --branch -p"
        env.update(envs)

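        # Trainers and pservers share the same CLI; only --role differs. The
        # doubled "{{}}" survives str.format() as a literal "{}" that is
        # later filled with current_id by _start_pserver/_start_trainer.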
        tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )

        ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )

        if self._model_dir:
            tr_cmd += " --model_dir {}".format(self._model_dir)
            ps_cmd += " --model_dir {}".format(self._model_dir)

        # Run dist train to compare with local results
        ps0, ps1 = self._start_pserver(ps_cmd, env)
        tr0, tr1 = self._start_trainer(tr_cmd, env)

        ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log = ps0
        ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log = ps1

        tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log = tr0
        tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log = tr1

        # Wait until the trainer processes terminate
        # time_out = 120
        time_out = 60
        cur_time = 0

        while True:
            stat0 = tr0_proc.poll()
            stat1 = tr1_proc.poll()

            if stat0 is not None and stat1 is not None:
                break
            else:
                time.sleep(0.5)
                cur_time += 0.5

            if cur_time >= time_out:
                tr0_proc.terminate()
                tr1_proc.terminate()
                tr0_proc.wait()
                tr1_proc.wait()
                break

        tr0_ret = tr0_proc.returncode
        tr1_ret = tr1_proc.returncode

        ps0_proc.kill()
        ps1_proc.kill()
        ps0_proc.wait()
        ps1_proc.wait()

        def is_listen_failed(logx):
            is_lf = False

            listen_rgx = "Fail to listen"

            with open(logx, "r") as rb:
                for line in rb.readlines():
                    if listen_rgx in line:
                        is_lf = True
                        break
            return is_lf

        def catlog(logx):
            basename = os.path.basename(logx)
            print(
                "\n================== Error {} begin =====================".format(
                    basename
                )
            )
            os.system("cat {}".format(logx))
            print(
                "================== Error {} end =====================\n".format(
                    basename
                )
            )

        if tr0_ret != 0 or tr1_ret != 0:
            if is_listen_failed(ps0_err_log) or is_listen_failed(ps1_err_log):
                print("found pserver port bind failure, skipping the error")
                tr0_ret, tr1_ret = 0, 0
            else:
                for out, err in [
                    (ps0_out_log, ps0_err_log),
                    (ps1_out_log, ps1_err_log),
                    (tr0_out_log, tr0_err_log),
                    (tr1_out_log, tr1_err_log),
                ]:
                    catlog(out)
                    catlog(err)

        for pipe in [
            tr0_err,
            tr0_out,
            tr1_err,
            tr1_out,
            ps0_err,
            ps0_out,
            ps1_err,
            ps1_out,
        ]:
            pipe.close()

        shutil.rmtree(gloo_path)

        self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check its log")
        self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check its log")

        return 0, 0

    def check_with_place(
        self, model_file, delta=1e-3, check_error_log=False, need_envs=None
    ):
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": "",
        }

        required_envs.update(need_envs or {})

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)


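# Entry point for the runner scripts spawned by _run_cluster: parse the CLI,
# build the role/strategy/network/optimizer, and run as pserver or trainer.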
def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run Fleet test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer']
    )
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--trainer_endpoints', type=str, required=False, default=""
    )
    parser.add_argument('--gloo_path', type=str, required=False, default="")
    parser.add_argument('--current_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--mode', type=str, required=False, default='geo')
    parser.add_argument(
        '--geo_sgd_need_push_nums', type=int, required=False, default=2
    )
    parser.add_argument('--reader', type=str, required=False, default='dataset')
    parser.add_argument('--test', type=int, required=False, default=0)
    parser.add_argument('--model_dir', type=str, required=False, default="")
    args = parser.parse_args()

    model = test_class()
    role = model.build_role(args)

    # for distributed inference
    if args.test and args.model_dir != "":
        avg_cost = model.net(args, is_train=False)
        dist_infer = DistributedInfer()
        dist_infer.init_distributed_infer_env(
            exe=model.get_executor(),
            loss=model.avg_cost,
            role_maker=role,
            dirname=args.model_dir,
        )

        if fleet.is_worker():
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
                fleet.stop_worker()
            return

        if fleet.is_server():
            return

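    # Normal training path: init fleet with the role, build the strategy,
    # the network, and the distributed optimizer, then run this process's
    # role (pserver or trainer).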
    fleet.init(role)
    strategy = model.build_strategy(args)
    avg_cost = model.net(args)
    model.build_optimizer(avg_cost, strategy)

    if args.role == "pserver":
        model.run_pserver(args)
    else:
        if args.reader == "dataset":
            model.run_dataset_trainer(args)
        else:
            model.run_pyreader_trainer(args)

        if args.test:
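            # Rebuild the network in test mode inside fresh programs, then
            # run distributed testing on the program produced by
            # DistributedInfer.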
            test_origin_program = paddle.static.Program()
            test_startup_program = paddle.static.Program()
            with paddle.static.program_guard(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            ):
                with paddle.utils.unique_name.guard():
                    avg_cost = model.net(args, is_train=False)
            dist_infer = DistributedInfer(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            )
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
        fleet.stop_worker()