#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""High-level unit-test harness for distributed fleet."""

import argparse
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import unittest
from contextlib import closing

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid as fluid
from paddle.distributed.fleet.utils.ps_util import DistributedInfer

paddle.enable_static()

__all__ = ['FleetDistRunnerBase', 'TestFleetBase', 'runtime_main']

RUN_STEP = 5
LEARNING_RATE = 0.01
DIST_UT_PORT = 0


class FleetDistRunnerBase(object):
    """
    Base class for the runner processes in a fleet distributed-training test.

    run_pserver / run_dataset_trainer / run_pyreader_trainer: after the role
        is initialized, the transpiler splits the program and the process
        executes its share of the job.
    net: implemented by child classes; builds the model network.
    do_dataset_training / do_pyreader_training: implemented by child classes;
        run the training loop with an executor.
    """

    def __init__(self):
        self._exe = None

    def build_role(self, args):
        if args.role.upper() == "PSERVER":
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.SERVER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","),
            )
        else:
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.WORKER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","),
            )
        self.role = role
        return role

    def build_strategy(self, args):
        if args.mode == "sync":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = False
        elif args.mode == "async":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
        elif args.mode == "geo":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
            self.strategy.a_sync_configs = {
                "k_steps": args.geo_sgd_need_push_nums
            }
        elif args.mode == "auto":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.auto = True

        self.dump_param = os.getenv("dump_param", "").split(",")
        self.dump_fields = os.getenv("dump_fields", "").split(",")
        self.dump_fields_path = os.getenv("dump_fields_path", "")
        debug = int(os.getenv("Debug", "0"))
        # TODO: update DistributedStrategy to support dumping params/fields,
        # then restore this branch to `if debug:`.
        if False:  # debug:
            self.strategy.set_debug_opt(
                {
                    "dump_param": self.dump_param,
                    "dump_fields": self.dump_fields,
                    "dump_fields_path": self.dump_fields_path,
                }
            )

        return self.strategy

    def build_optimizer(self, avg_cost, strategy):
        use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
        grad_clip = None
        if use_grad_clip:
            # GRAD_CLIP selects the clip op: 1: clip_by_value;
            # 2: clip_by_norm; 3: clip_by_global_norm
            if use_grad_clip == 1:
                grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
            elif use_grad_clip == 2:
                grad_clip = paddle.nn.ClipGradByNorm(2.0)
            elif use_grad_clip == 3:
                grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)

        use_decay = int(os.getenv("USE_DECAY", "0"))
        if use_decay:
            scheduler = paddle.optimizer.lr.ExponentialDecay(
                learning_rate=LEARNING_RATE, gamma=0.999, verbose=True
            )
            optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
            """
            # learning rate decay method before 2.0
            optimizer = fluid.optimizer.SGD(
                learning_rate=fluid.layers.exponential_decay(
                    learning_rate=LEARNING_RATE,
                    decay_steps=500,
                    decay_rate=0.969,
                    staircase=True))
            """
        else:
            optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)
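
    # Illustrative environment settings (hypothetical values): running a test
    # with GRAD_CLIP=3 and USE_DECAY=1 exercises ClipGradByGlobalNorm(2.0)
    # together with the ExponentialDecay schedule above.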

    def run_pserver(self, args):
        fleet.init_server()
        fleet.run_server()

    def run_dataset_trainer(self, args):
        out = self.do_dataset_training(fleet)

    def run_pyreader_trainer(self, args):
        out = self.do_pyreader_training(fleet)

    def net(self, args, batch_size=4, lr=0.01):
        raise NotImplementedError(
            "net should be implemented by child classes."
        )

    def get_executor(self):
        if self._exe is None:
            device_env = os.getenv("DEVICE", 'cpu')
            if device_env == 'cpu':
                device = fluid.CPUPlace()
            elif device_env == 'gpu':
                device = fluid.CUDAPlace(0)
            self._exe = fluid.Executor(device)
        return self._exe

    def do_dataset_training(self, fleet):
        raise NotImplementedError(
            "do_dataset_training should be implemented by child classes."
        )

    def do_pyreader_training(self, fleet):
        raise NotImplementedError(
            "do_pyreader_training should be implemented by child classes."
        )

    def do_distributed_testing(self, fleet):
        raise NotImplementedError(
            "do_distributed_testing should be implemented by child classes."
        )


class TestFleetBase(unittest.TestCase):
    """
    _start_pserver, _start_trainer: build and launch the subprocess commands
        under test.
    _run_cluster: run the distributed program with multiple local processes.
    """

    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def tearDown(self):
        t = time.time() - self.startTime
        print('%s: %.3f' % (self.__class__.__name__, t))

    def setUp(self):
        self.startTime = time.time()

        self._mode = "sync"
        self._reader = "pyreader"
        self._trainers = 2
        self._pservers = 2
        self._need_test = 0
        self._model_dir = ""
        self._port_set = set()

        # A port range can be pinned via PADDLE_DIST_UT_PORT; otherwise free
        # ports are probed at random for each endpoint.
        global DIST_UT_PORT
        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))

        if DIST_UT_PORT:
            print("set begin_port:", DIST_UT_PORT)
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT,
                DIST_UT_PORT + 1,
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT + 2,
                DIST_UT_PORT + 3,
            )
            DIST_UT_PORT += 4
        else:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )

        self._python_interp = sys.executable
        self._geo_sgd_need_push_nums = 5
        self._grad_clip_mode = 0
        self._setup_config()

    def _find_free_port(self):
        def __free_port():
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
                # Bind to port 0 so the OS assigns an unused port.
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _start_pserver(self, cmd, required_envs):
        ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        ps0_err_log = os.path.join(log_dirname, log_prename + "ps0_stderr.log")
        ps1_err_log = os.path.join(log_dirname, log_prename + "ps1_stderr.log")
        ps0_out_log = os.path.join(log_dirname, log_prename + "ps0_stdout.log")
        ps1_out_log = os.path.join(log_dirname, log_prename + "ps1_stdout.log")

        ps0_err = open(ps0_err_log, "wb+")
        ps1_err = open(ps1_err_log, "wb+")

        ps0_out = open(ps0_out_log, "wb+")
        ps1_out = open(ps1_out_log, "wb+")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=ps0_out,
            stderr=ps0_err,
            env=required_envs,
        )

        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=ps1_out,
            stderr=ps1_err,
            env=required_envs,
        )

        return (
            (ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log),
            (ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log),
        )

    def _start_trainer(self, cmd, required_envs):
        tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        tr0_err_log = os.path.join(log_dirname, log_prename + "tr0_stderr.log")
        tr1_err_log = os.path.join(log_dirname, log_prename + "tr1_stderr.log")
        tr0_out_log = os.path.join(log_dirname, log_prename + "tr0_stdout.log")
        tr1_out_log = os.path.join(log_dirname, log_prename + "tr1_stdout.log")

        tr0_err = open(tr0_err_log, "wb+")
        tr1_err = open(tr1_err_log, "wb+")

        tr0_out = open(tr0_out_log, "wb+")
        tr1_out = open(tr1_out_log, "wb+")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=tr0_out,
            stderr=tr0_err,
            env=required_envs,
        )

        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=tr1_out,
            stderr=tr1_err,
            env=required_envs,
        )

        return (
            (tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log),
            (tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log),
        )

    def _run_cluster(self, model, envs):
        env = {'GRAD_CLIP': str(self._grad_clip_mode), 'WITH_DISTRIBUTE': 'ON'}
        python_path = self._python_interp
        gloo_path = tempfile.mkdtemp()

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            python_path += " -m coverage run --branch -p"
        env.update(envs)

        tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )

        ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )
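
        # For illustration only, a rendered trainer command looks roughly
        # like this (hypothetical values):
        #   python dist_fleet_ctr.py --role trainer
        #       --endpoints 127.0.0.1:6070,127.0.0.1:6071
        #       --trainer_endpoints 127.0.0.1:6072,127.0.0.1:6073
        #       --current_id 0 --trainers 2 --mode async
        #       --geo_sgd_need_push_nums 5 --reader pyreader
        #       --gloo_path /tmp/tmpXXXX --test 0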

        if self._model_dir:
            tr_cmd += " --model_dir {}".format(self._model_dir)
            ps_cmd += " --model_dir {}".format(self._model_dir)

        # Run distributed training and compare against local results.
        ps0, ps1 = self._start_pserver(ps_cmd, env)
        tr0, tr1 = self._start_trainer(tr_cmd, env)

        ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log = ps0
        ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log = ps1

        tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log = tr0
        tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log = tr1

        # Wait until the trainer processes terminate, or kill them once the
        # timeout expires.
        # time_out = 120
        time_out = 60
        cur_time = 0

        while True:
            stat0 = tr0_proc.poll()
            stat1 = tr1_proc.poll()

            if stat0 is not None and stat1 is not None:
                break
            else:
                time.sleep(0.5)
                cur_time += 0.5

            if cur_time >= time_out:
                tr0_proc.terminate()
                tr1_proc.terminate()
                tr0_proc.wait()
                tr1_proc.wait()
                break

        tr0_ret = tr0_proc.returncode
        tr1_ret = tr1_proc.returncode

        ps0_proc.kill()
        ps1_proc.kill()
        ps0_proc.wait()
        ps1_proc.wait()

        def is_listen_failed(logx):
            is_lf = False

            listen_rgx = "Fail to listen"

            with open(logx, "r") as rb:
                for line in rb.readlines():
                    if listen_rgx in line:
                        is_lf = True
                        break
            return is_lf

        def catlog(logx):
            basename = os.path.basename(logx)
            print(
                "\n================== Error {} begin =====================".format(
                    basename
                )
            )
            os.system("cat {}".format(logx))
            print(
                "================== Error {} end =====================\n".format(
                    basename
                )
            )

        if tr0_ret != 0 or tr1_ret != 0:
            if is_listen_failed(ps0_err_log) or is_listen_failed(ps1_err_log):
                print("parameter server port bind failed, skipping the error")
                tr0_ret, tr1_ret = 0, 0
            else:
                for out, err in [
                    (ps0_out_log, ps0_err_log),
                    (ps1_out_log, ps1_err_log),
                    (tr0_out_log, tr0_err_log),
                    (tr1_out_log, tr1_err_log),
                ]:
                    catlog(out)
                    catlog(err)

        for pipe in [
            tr0_err,
            tr0_out,
            tr1_err,
            tr1_out,
            ps0_err,
            ps0_out,
            ps1_err,
            ps1_out,
        ]:
            pipe.close()

        shutil.rmtree(gloo_path)

        self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
        self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")

        return 0, 0

    def check_with_place(
        self, model_file, delta=1e-3, check_error_log=False, need_envs=None
    ):
        need_envs = need_envs or {}
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5 sec, to fail fast
            "http_proxy": "",
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)


def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run Fleet test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer']
    )
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--trainer_endpoints', type=str, required=False, default=""
    )
    parser.add_argument('--gloo_path', type=str, required=False, default="")
    parser.add_argument('--current_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--mode', type=str, required=False, default='geo')
    parser.add_argument(
        '--geo_sgd_need_push_nums', type=int, required=False, default=2
    )
    parser.add_argument('--reader', type=str, required=False, default='dataset')
    parser.add_argument('--test', type=int, required=False, default=0)
    parser.add_argument('--model_dir', type=str, required=False, default="")
    args = parser.parse_args()

    model = test_class()
    role = model.build_role(args)

    # Distributed inference: restore the model saved in --model_dir and run
    # distributed testing without training.
    if args.test and args.model_dir != "":
        avg_cost = model.net(args, is_train=False)
        dist_infer = DistributedInfer()
        dist_infer.init_distributed_infer_env(
            exe=model.get_executor(),
            loss=model.avg_cost,
            role_maker=role,
            dirname=args.model_dir,
        )

        if fleet.is_worker():
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
                fleet.stop_worker()
            return

        if fleet.is_server():
            return

    fleet.init(role)
    strategy = model.build_strategy(args)
    avg_cost = model.net(args)
    model.build_optimizer(avg_cost, strategy)

    if args.role == "pserver":
        model.run_pserver(args)
    else:
        if args.reader == "dataset":
            model.run_dataset_trainer(args)
        else:
            model.run_pyreader_trainer(args)

        if args.test:
            test_origin_program = paddle.static.Program()
            test_startup_program = paddle.static.Program()
            with paddle.static.program_guard(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            ):
                with paddle.utils.unique_name.guard():
                    avg_cost = model.net(args, is_train=False)
            dist_infer = DistributedInfer(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            )
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
        fleet.stop_worker()
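

# Typical entry point of a model script built on this harness (hypothetical
# runner class name; the real FleetDistRunnerBase subclasses live in the
# dist_fleet_*.py model files):
#
#   if __name__ == "__main__":
#       runtime_main(MyCTRRunner)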