#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
High-level unit tests for distributed fleet.
"""

import argparse
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import unittest
from contextlib import closing

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid as fluid
from paddle.distributed.fleet.utils.ps_util import DistributedInfer

paddle.enable_static()

__all__ = ['FleetDistRunnerBase', 'TestFleetBase', 'runtime_main']

RUN_STEP = 5
LEARNING_RATE = 0.01
DIST_UT_PORT = 0


class FleetDistRunnerBase:
    """
    run_pserver, run_trainer: after initializing the role, use the transpiler to split the program.
    net: implemented by child classes; builds the network of the model.
    do training: run the program with the executor.
    """

    def __init__(self):
        self._exe = None

    def build_role(self, args):
        if args.role.upper() == "PSERVER":
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.SERVER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","),
            )
        else:
            role = role_maker.UserDefinedRoleMaker(
                is_collective=False,
                init_gloo=False,
                path=args.gloo_path,
                current_id=args.current_id,
                role=role_maker.Role.WORKER,
                worker_endpoints=args.trainer_endpoints.split(","),
                server_endpoints=args.endpoints.split(","),
            )
        self.role = role
        return role

    def build_strategy(self, args):
        if args.mode == "sync":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = False
        elif args.mode == "async":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
        elif args.mode == "geo":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.a_sync = True
            self.strategy.a_sync_configs = {
                "k_steps": args.geo_sgd_need_push_nums
            }
        elif args.mode == "auto":
            self.strategy = paddle.distributed.fleet.DistributedStrategy()
            self.strategy.auto = True

        self.dump_param = os.getenv("dump_param", "").split(",")
        self.dump_fields = os.getenv("dump_fields", "").split(",")
        self.dump_fields_path = os.getenv("dump_fields_path", "")
        debug = int(os.getenv("Debug", "0"))
        # TODO(update strategy to support dump params)
        if False:  # debug:
            self.strategy.set_debug_opt(
                {
                    "dump_param": self.dump_param,
                    "dump_fields": self.dump_fields,
                    "dump_fields_path": self.dump_fields_path,
                }
            )

        return self.strategy

    def build_optimizer(self, avg_cost, strategy):
        use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
        grad_clip = None
        if use_grad_clip:
            # 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
            if use_grad_clip == 1:
                grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
            elif use_grad_clip == 2:
                grad_clip = paddle.nn.ClipGradByNorm(2.0)
            elif use_grad_clip == 3:
                grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)

        use_decay = int(os.getenv("USE_DECAY", "0"))
        if use_decay:
            scheduler = paddle.optimizer.lr.ExponentialDecay(
                learning_rate=LEARNING_RATE, gamma=0.999, verbose=True
            )
            optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
            """
            # learning rate decay method before 2.0
            optimizer = fluid.optimizer.SGD(
                learning_rate=fluid.layers.exponential_decay(
                    learning_rate=LEARNING_RATE,
                    decay_steps=500,
                    decay_rate=0.969,
                    staircase=True))
            """
        else:
            optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

    def run_pserver(self, args):
        fleet.init_server()
        fleet.run_server()

    def run_dataset_trainer(self, args):
        out = self.do_dataset_training(fleet)

    def run_pyreader_trainer(self, args):
        out = self.do_pyreader_training(fleet)

    def net(self, args, batch_size=4, lr=0.01):
        raise NotImplementedError(
            "get_model should be implemented by child classes."
        )

    def get_executor(self):
        if self._exe is None:
            device_env = os.getenv("DEVICE", 'cpu')
            if device_env == 'cpu':
                device = fluid.CPUPlace()
            elif device_env == 'gpu':
                device = fluid.CUDAPlace(0)
            else:
                raise ValueError("unsupported DEVICE: %s" % device_env)
            self._exe = fluid.Executor(device)
        return self._exe

    def do_dataset_training(self, fleet):
        raise NotImplementedError(
            "do_dataset_training should be implemented by child classes."
        )

    def do_pyreader_training(self, fleet):
        raise NotImplementedError(
            "do_pyreader_training should be implemented by child classes."
        )

    def do_distributed_testing(self, fleet):
        raise NotImplementedError(
            "do_distributed_testing should be implemented by child classes."
        )


class TestFleetBase(unittest.TestCase):
    """
    start_pserver, start_trainer: build the launch commands used by the test.
    run_cluster: run the distributed program with multiple processes.
    """

    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def tearDown(self):
        t = time.time() - self.startTime
        print('%s: %.3f' % (self.__class__.__name__, t))

    def setUp(self):
        self.startTime = time.time()

        self._mode = "sync"
        self._reader = "pyreader"
        self._trainers = 2
        self._pservers = 2
        self._need_test = 0
        self._model_dir = ""
        self._port_set = set()

        global DIST_UT_PORT
        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))

        if DIST_UT_PORT:
            print("set begin_port:", DIST_UT_PORT)
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT,
                DIST_UT_PORT + 1,
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT + 2,
                DIST_UT_PORT + 3,
            )
            DIST_UT_PORT += 4
        else:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )
            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(),
                self._find_free_port(),
            )

        self._python_interp = sys.executable
        self._geo_sgd_need_push_nums = 5
        self._grad_clip_mode = 0
        self._setup_config()

    def _find_free_port(self):
        def __free_port():
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _start_pserver(self, cmd, required_envs):
        ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        ps0_err_log = os.path.join(log_dirname, log_prename + "ps0_stderr.log")
        ps1_err_log = os.path.join(log_dirname, log_prename + "ps1_stderr.log")
        ps0_out_log = os.path.join(log_dirname, log_prename + "ps0_stdout.log")
        ps1_out_log = os.path.join(log_dirname, log_prename + "ps1_stdout.log")

        ps0_err = open(ps0_err_log, "wb+")
        ps1_err = open(ps1_err_log, "wb+")

        ps0_out = open(ps0_out_log, "wb+")
        ps1_out = open(ps1_out_log, "wb+")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=ps0_out,
            stderr=ps0_err,
            env=required_envs,
        )

        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=ps1_out,
            stderr=ps1_err,
            env=required_envs,
        )

        return (
            (ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log),
            (ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log),
        )

    def _start_trainer(self, cmd, required_envs):
        tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)

        log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
        log_prename = required_envs.get("LOG_PREFIX", "")

        if log_dirname:
            log_prename += "_"

        tr0_err_log = os.path.join(log_dirname, log_prename + "tr0_stderr.log")
        tr1_err_log = os.path.join(log_dirname, log_prename + "tr1_stderr.log")
        tr0_out_log = os.path.join(log_dirname, log_prename + "tr0_stdout.log")
        tr1_out_log = os.path.join(log_dirname, log_prename + "tr1_stdout.log")

        tr0_err = open(tr0_err_log, "wb+")
        tr1_err = open(tr1_err_log, "wb+")

        tr0_out = open(tr0_out_log, "wb+")
        tr1_out = open(tr1_out_log, "wb+")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=tr0_out,
            stderr=tr0_err,
            env=required_envs,
        )

        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=tr1_out,
            stderr=tr1_err,
            env=required_envs,
        )

        return (
            (tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log),
            (tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log),
        )

    def _run_cluster(self, model, envs):
        env = {'GRAD_CLIP': str(self._grad_clip_mode), 'WITH_DISTRIBUTE': 'ON'}
        python_path = self._python_interp
        gloo_path = tempfile.mkdtemp()

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            python_path += " -m coverage run --branch -p"
        env.update(envs)

        tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )

        ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --test {9}".format(
            python_path,
            model,
            self._ps_endpoints,
            self._tr_endpoints,
            self._trainers,
            self._mode,
            self._geo_sgd_need_push_nums,
            self._reader,
            gloo_path,
            self._need_test,
        )

        if self._model_dir:
            tr_cmd += " --model_dir {}".format(self._model_dir)
            ps_cmd += " --model_dir {}".format(self._model_dir)

        # Run dist train to compare with local results
        ps0, ps1 = self._start_pserver(ps_cmd, env)
        tr0, tr1 = self._start_trainer(tr_cmd, env)

        ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log = ps0
        ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log = ps1

        tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log = tr0
        tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log = tr1

        # Wait until the trainer processes terminate
        # time_out = 120
        time_out = 60
        cur_time = 0

        while True:
            stat0 = tr0_proc.poll()
            stat1 = tr1_proc.poll()

            if stat0 is not None and stat1 is not None:
                break
            else:
                time.sleep(0.5)
                cur_time += 0.5

            if cur_time >= time_out:
                tr0_proc.terminate()
                tr1_proc.terminate()
                tr0_proc.wait()
                tr1_proc.wait()
                break

        tr0_ret = tr0_proc.returncode
        tr1_ret = tr1_proc.returncode

        ps0_proc.kill()
        ps1_proc.kill()
        ps0_proc.wait()
        ps1_proc.wait()

        def is_listen_failed(logx):
            is_lf = False

            listen_rgx = "Fail to listen"

            with open(logx, "r") as rb:
                for line in rb.readlines():
                    if listen_rgx in line:
                        is_lf = True
                        break
            return is_lf

        def catlog(logx):
            basename = os.path.basename(logx)
            print(
                "\n================== Error {} begin =====================".format(
                    basename
                )
            )
            os.system("cat {}".format(logx))
            print(
                "================== Error {} end =====================\n".format(
                    basename
                )
            )

        if tr0_ret != 0 or tr1_ret != 0:
            if is_listen_failed(ps0_err_log) or is_listen_failed(ps1_err_log):
                print("find parameter server port bind failed, skip the error")
                tr0_ret, tr1_ret = 0, 0
            else:
                for out, err in [
                    (ps0_out_log, ps0_err_log),
                    (ps1_out_log, ps1_err_log),
                    (tr0_out_log, tr0_err_log),
                    (tr1_out_log, tr1_err_log),
                ]:
                    catlog(out)
                    catlog(err)

        for pipe in [
            tr0_err,
            tr0_out,
            tr1_err,
            tr1_out,
            ps0_err,
            ps0_out,
            ps1_err,
            ps1_out,
        ]:
            pipe.close()

        shutil.rmtree(gloo_path)

        self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
        self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")

        return 0, 0

    def check_with_place(
        self, model_file, delta=1e-3, check_error_log=False, need_envs={}
    ):
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": "",
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)


def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run Fleet test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer']
    )
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--trainer_endpoints', type=str, required=False, default=""
    )
    parser.add_argument('--gloo_path', type=str, required=False, default="")
    parser.add_argument('--current_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--mode', type=str, required=False, default='geo')
    parser.add_argument(
        '--geo_sgd_need_push_nums', type=int, required=False, default=2
    )
    parser.add_argument('--reader', type=str, required=False, default='dataset')
    parser.add_argument('--test', type=int, required=False, default=0)
    parser.add_argument('--model_dir', type=str, required=False, default="")
    args = parser.parse_args()

    model = test_class()
    role = model.build_role(args)

    # for distributed inference
    if args.test and args.model_dir != "":
        avg_cost = model.net(args, is_train=False)
        dist_infer = DistributedInfer()
        dist_infer.init_distributed_infer_env(
            exe=model.get_executor(),
            loss=model.avg_cost,
            role_maker=role,
            dirname=args.model_dir,
        )

        if fleet.is_worker():
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
                fleet.stop_worker()
            return

        if fleet.is_server():
            return

    fleet.init(role)
    strategy = model.build_strategy(args)
    avg_cost = model.net(args)
    model.build_optimizer(avg_cost, strategy)

    if args.role == "pserver":
        model.run_pserver(args)
    else:
        if args.reader == "dataset":
            model.run_dataset_trainer(args)
        else:
            model.run_pyreader_trainer(args)

        if args.test:
            test_origin_program = paddle.static.Program()
            test_startup_program = paddle.static.Program()
            with paddle.static.program_guard(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            ):
                with paddle.utils.unique_name.guard():
                    avg_cost = model.net(args, is_train=False)
            dist_infer = DistributedInfer(
                main_program=test_origin_program,
                startup_program=test_startup_program,
            )
            with paddle.static.program_guard(
                main_program=dist_infer.get_dist_infer_program()
            ):
                model.do_distributed_testing(fleet)
        fleet.stop_worker()
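

# A model script built on this harness subclasses FleetDistRunnerBase and
# calls runtime_main as its entry point (a sketch; TestDistCTR is a
# hypothetical child class defined in that script):
#
#     if __name__ == "__main__":
#         runtime_main(TestDistCTR)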