#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

import ast
import unittest
import os
import sys
import subprocess
import six
import argparse
import pickle
import random
import numpy as np
import time

import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.dygraph as dygraph
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker

RUN_STEP = 5
DEFAULT_BATCH_SIZE = 2
DIST_UT_PORT = 0
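# NOTE: DIST_UT_PORT acts as a module-level port cursor: when the
# PADDLE_DIST_UT_PORT environment variable is set, TestDistBase.setUp() assigns
# consecutive ports starting from that value; otherwise free ports are probed
# via _find_free_port().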


def print_to_out(out_losses):
    sys.stdout.buffer.write(pickle.dumps(out_losses))


def print_to_err(class_name, log_str):
    localtime = time.asctime(time.localtime(time.time()))
    print_str = localtime + "\t" + class_name + "\t" + log_str
    sys.stderr.buffer.write(pickle.dumps(print_str))


def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)
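
# NOTE: the trainer and pserver subprocesses report results back to the parent
# test process through their pipes: print_to_out() pickles the collected losses
# to stdout and print_to_err() pickles timestamped log lines to stderr, which
# TestDistBase recovers with pickle.loads() on the captured output.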


class TestDistRunnerBase(object):

    def get_model(self,
                  batch_size=DEFAULT_BATCH_SIZE,
                  lr=0.1,
                  single_device=False,
                  use_dgc=False,
                  dist_strategy=None):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False,
                       current_endpoint=None,
                       nccl_comm_num=1,
                       hogwild_mode=False):
        # NOTE: import fluid until runtime, or else forking processes will cause error.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        config.sync_mode = sync_mode
        config.runtime_split_send_recv = hogwild_mode

        if nccl_comm_num > 1:
            config.nccl_comm_num = nccl_comm_num
        # config.runtime_split_send_recv = True
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(trainer_id=trainer_id,
                    program=main_program,
                    pservers=pserver_endpoints,
                    trainers=trainers,
                    sync_mode=sync_mode,
                    current_endpoint=current_endpoint)
        return t

    @staticmethod
    def get_lr_scheduler(program):
        # NOTE: 'lr_sheduler' (sic) matches the attribute name that Paddle sets
        # on the Program.
        lr_sheduler = None
        if hasattr(program, 'lr_sheduler'):
            from paddle.optimizer.lr import LRScheduler
            lr_sheduler = program.lr_sheduler
            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
        return lr_sheduler

    def run_pserver(self, args):
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize

        t = self.get_transpiler(trainer_id=args.trainer_id,
                                main_program=fluid.default_main_program(),
                                pserver_endpoints=args.endpoints,
                                trainers=args.trainers,
                                sync_mode=args.sync_mode,
                                dc_asgd=args.dc_asgd,
                                hogwild_mode=args.hogwild)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        print_to_err(type(self).__name__, "run pserver startup program done.")
        exe.run(pserver_prog)
        print_to_err(type(self).__name__, "run pserver main program done.")

    def run_pipeline_trainer(self, args):
        self.lr = args.lr

        dist_strategy = DistributedStrategy()
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict, data_loader = \
            self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)

        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        eprint(type(self).__name__, "device_id: %d." % device_id)
        place = fluid.CUDAPlace(device_id)

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        eprint(type(self).__name__, "run worker startup program done.")

        data_loader.set_sample_list_generator(train_reader, place)
        data_loader.start()
        print_to_err(type(self).__name__, "begin to train on trainer")
        out_losses = []

        main_program = fluid.default_main_program()
        lr_sheduler = self.get_lr_scheduler(main_program)
        for i in six.moves.xrange(RUN_STEP):
            loss = exe.run(main_program, fetch_list=[avg_cost])
            loss = loss[0] if loss else None
            out_losses.append(loss)
            print_to_err(type(self).__name__, "run step %d finished" % i)
            if lr_sheduler is not None:
                lr_sheduler.step()

        data_loader.reset()
        print_to_err(type(self).__name__, "trainer run finished")

        sys.stdout.buffer.write(pickle.dumps(out_losses))

    def run_use_fleet_api_20_trainer(self, args):
        """
        1. DistributedStrategy handling is removed here and left to get_model()
        2. to run with the fleet 2.0 api, set both _use_fleet_api and _use_fleet_api_20 to True
        3. model save is not supported for now
        """
        assert args.update_method in ["nccl2", "bkcl", "gloo"]

        self.lr = args.lr
        print_to_err("use_fleet 2.0", "fleet.node_num:")

        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if fluid.core.is_compiled_with_cuda():
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif fluid.core.is_compiled_with_xpu():
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            raise ValueError(
                "fleet api must run with paddlepaddle-xpu or paddlepaddle-gpu.")

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        eprint(type(self).__name__, "run worker startup program done.")

        feed_var_list = [
            var for var in
            fluid.default_main_program().global_block().vars.values()
            if var.is_data
        ]

        eprint("feed_var_list:", feed_var_list)

        if feed_var_list[0].name == 'label':
            feed_var_list = feed_var_list[::-1]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            origin_batch = next(reader_generator)
            if paddle.distributed.get_world_size(
            ) == 1 and args.update_method == 'gloo':  # Gloo single mode
                return origin_batch

            elif args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch
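
        # NOTE: with --use_reader_alloc, each batch is split between the two
        # trainers by sample index parity, e.g. batch [s0, s1, s2, s3] gives
        # trainer 0 -> [s0, s2] and trainer 1 -> [s1, s3]; the get_data()
        # helpers in the other trainer methods below partition the same way.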

        print_to_err(type(self).__name__, "begin to train on trainer")
        out_losses = []
        for i in six.moves.xrange(RUN_STEP):
            loss, = exe.run(fluid.default_main_program(),
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
            print_to_err(type(self).__name__, "run step %d finished" % i)
        print_to_err(type(self).__name__, "trainer run finished")
        print_to_err(type(self).__name__, "dist losses: {}".format(out_losses))

        sys.stdout.buffer.write(pickle.dumps(out_losses))

    def run_use_fleet_api_trainer(self, args):
        assert args.update_method in ["nccl2", "bkcl"]

        self.lr = args.lr

        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.num_threads = 1

        dist_strategy = DistributedStrategy()
        dist_strategy.exec_strategy = exec_strategy
        dist_strategy.fuse_memory_size = 1  # MB
        dist_strategy.fuse_laryer_size = 1
        if args.use_local_sgd:
            dist_strategy.use_local_sgd = True
        if args.ut4grad_allreduce:
            dist_strategy._ut4grad_allreduce = True
        if args.sync_batch_norm:
            dist_strategy.sync_batch_norm = True

        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        print_to_err("use_fleet", "fleet.node_num:")
        # "fleet.node_id:", fleet.node_id(),
        # "fleet.trainer_num:", fleet.worker_num())

        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)

        trainer_prog = fleet._origin_program
        dist_prog = fleet.main_program

        if fluid.core.is_compiled_with_cuda():
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif fluid.core.is_compiled_with_xpu():
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            raise ValueError(
                "fleet api must run with paddlepaddle-xpu or paddlepaddle-gpu.")

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        eprint(type(self).__name__, "run worker startup program done.")

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        eprint("feed_var_list:", feed_var_list)

        # temporary: added to pass the python35 gcc8 CI
        # FIXME(gongweibao, wangxi): need to fix the fleet api program order
        if feed_var_list[0].name == 'label':
            feed_var_list = feed_var_list[::-1]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        print_to_err(type(self).__name__, "begin to train on trainer")
        out_losses = []
        for i in six.moves.xrange(RUN_STEP):
            loss, = exe.run(dist_prog,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
            print_to_err(type(self).__name__, "run step %d finished" % i)
        print_to_err(type(self).__name__, "trainer run finished")

        sys.stdout.buffer.write(pickle.dumps(out_losses))

        if args.save_model:
            model_save_dir = "/tmp"
            if fleet.worker_index() == 0:
                model_save_dir_fluid = os.path.join(model_save_dir,
                                                    "fluid_persistables")
                model_save_dir_fleet = os.path.join(model_save_dir,
                                                    "fleet_persistables")
                infer_save_dir_fluid = os.path.join(model_save_dir,
                                                    "fluid_infer")
                infer_save_dir_fleet = os.path.join(model_save_dir,
                                                    "fleet_infer")
            else:
                model_save_dir_fluid = os.path.join(model_save_dir,
                                                    "fluid_persistables_2")
                model_save_dir_fleet = os.path.join(model_save_dir,
                                                    "fleet_persistables_2")
                infer_save_dir_fluid = os.path.join(model_save_dir,
                                                    "fluid_infer_2")
                infer_save_dir_fleet = os.path.join(model_save_dir,
                                                    "fleet_infer_2")
            fluid.io.save_persistables(exe, model_save_dir_fluid,
                                       fleet._origin_program)
            fleet.save_persistables(executor=exe, dirname=model_save_dir_fleet)
            feeded_var_names = [var.name for var in feed_var_list]
            fluid.io.save_inference_model(infer_save_dir_fluid,
                                          feeded_var_names, [avg_cost], exe,
                                          fleet._origin_program)
            fleet.save_inference_model(exe, infer_save_dir_fleet,
                                       feeded_var_names, [avg_cost])

    def run_trainer(self, args):
        self.lr = args.lr
        if args.nccl2_reduce_layer_local_run:
            test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
                self.get_model(batch_size=args.batch_size, single_device=True)
        elif args.use_dgc:
            test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
                self.get_model(batch_size=args.batch_size, use_dgc=args.use_dgc)
        else:
            test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
                self.get_model(batch_size=args.batch_size)

        if args.update_method == "pserver":
            print_to_err(
                type(self).__name__,
                "begin to run transpile on trainer with pserver mode")
            t = self.get_transpiler(trainer_id=args.trainer_id,
                                    main_program=fluid.default_main_program(),
                                    pserver_endpoints=args.endpoints,
                                    trainers=args.trainers,
                                    sync_mode=args.sync_mode,
                                    dc_asgd=args.dc_asgd,
                                    hogwild_mode=args.hogwild)

            trainer_prog = t.get_trainer_program()
            print_to_err(
                type(self).__name__,
                "get trainer program done with pserver mode.")
        elif args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            config.nccl_comm_num = args.nccl_comm_num
            if args.use_hallreduce:
                config.use_hierarchical_allreduce = True
                config.hierarchical_allreduce_inter_nranks = args.hallreduce_inter_nranks
            print_to_err(
                type(self).__name__,
                "begin to run transpile on trainer with nccl2 mode")
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(args.trainer_id,
                              program=fluid.default_main_program(),
                              startup_program=fluid.default_startup_program(),
                              trainers=args.endpoints,
                              current_endpoint=args.current_endpoint)
            print_to_err(
                type(self).__name__,
                "get trainer program done with nccl2 mode.")
            trainer_prog = fluid.default_main_program()
        else:
            print_to_err(
                type(self).__name__,
                "do nothing about main program, just use it")
            trainer_prog = fluid.default_main_program()
            print_to_err(type(self).__name__, "use main program done.")

        # FIXME(gongwb): wait pserver initialization.
        time.sleep(1)

        if args.use_cuda:
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        print_to_err(type(self).__name__, "run worker startup program done.")

        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.num_threads = 1

        build_stra = fluid.BuildStrategy()
        # FIXME force disable enable_inplace and memory_optimize
        build_stra.enable_inplace = False
        build_stra.memory_optimize = False

        if args.fuse_all_reduce is not None:
            sys.stderr.write('fuse_all_reduce={}'.format(args.fuse_all_reduce))
            build_stra.fuse_all_reduce_ops = args.fuse_all_reduce

        if args.hogwild:
            build_stra.async_mode = True

        if args.enable_backward_deps:
            build_stra.enable_backward_optimizer_op_deps = True

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        pass_builder = None
        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
            mypass.set("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
            build_stra.num_trainers = len(args.endpoints.split(","))
            build_stra.trainer_id = args.trainer_id
        else:
            # case args.update_method == "nccl2_reduce_layer":
            build_stra.num_trainers = 1
            build_stra.trainer_id = 0

        print_to_err(type(self).__name__, "begin to compile with data parallel")
        binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_stra,
            exec_strategy=exec_strategy)
        print_to_err(type(self).__name__, "program compiled with data parallel")

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        lr_scheduler = self.get_lr_scheduler(trainer_prog)
        print_to_err(type(self).__name__, "begin to train on trainer")
        out_losses = []
        for i in six.moves.xrange(RUN_STEP):
            loss, = exe.run(binary,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
            print_to_err(type(self).__name__, "run step %d finished" % i)
            if lr_scheduler is not None:
                lr_scheduler.step()

        print_to_err(type(self).__name__, "trainer run finished")

        print_to_out(out_losses)


class TestParallelDyGraphRunnerBase(object):

    def get_model(self):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    def run_one_loop(self, model, opt, data):
        raise NotImplementedError(
            "run_one_loop should be implemented by child classes.")

    def _get_data(self, batch, args):
        if paddle.distributed.get_world_size(
        ) == 1 and args.update_method == 'gloo':  # Gloo single mode
            return batch
        elif args.update_method != "local":
            new_batch = []

            # NOTE(@xiongkun03): args.diff_batch means the batch length differs
            # across ranks: e.g. for batch = [2,3,4,5], the first rank gets [2]
            # and the second rank gets [3,4,5].
            # This path is used to test sparse_embedding_differ_length.
            if hasattr(args, "diff_batch") and args.diff_batch:
                assert len(
                    batch) > 2, "in differ_batch mode, len(batch) must be > 2."
                if paddle.distributed.get_rank() == 0:
                    new_batch.append(batch[0])
                elif paddle.distributed.get_rank() == 1:
                    new_batch.extend([_ for _ in batch[1:]])
                else:
                    raise NotImplementedError(
                        "Current TestParallelDyGraphRunnerBase doesn't support world_size > 2"
                    )
                return new_batch
            else:
                for offset, item in enumerate(batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
        else:
            return batch

    def run_trainer(self, args):
        seed = 90
        if args.update_method == 'gloo':
            place = fluid.CPUPlace()
        elif fluid.core.is_compiled_with_cuda():
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif fluid.core.is_compiled_with_xpu():
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        elif fluid.core.is_compiled_with_npu():
            device_id = int(os.getenv("FLAGS_selected_npus", "0"))
            place = fluid.NPUPlace(device_id)
        elif fluid.core.is_compiled_with_mlu():
            device_id = int(os.getenv("FLAGS_selected_mlus", "0"))
            place = fluid.MLUPlace(device_id)
        else:
            raise RuntimeError(
                "Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.")

        with fluid.dygraph.guard(place):
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed
            np.random.seed(seed)
            import random
            random.seed(seed)
            model, train_reader, opt = self.get_model()
            nranks = len(args.endpoints.split(",")) if args.endpoints else 1

            if args.update_method in ["nccl2", "bkcl", "hccl", "cncl"]:
                strategy = dygraph.parallel.ParallelStrategy()
                strategy.nranks = nranks
                strategy.local_rank = args.trainer_id
                strategy.trainer_endpoints = args.endpoints.split(",")
                strategy.current_endpoint = args.current_endpoint
                paddle.distributed.init_parallel_env()
                print_to_err(
                    type(self).__name__,
                    "begin to prepare context in dygraph with nccl2")
                dygraph.parallel.prepare_context(strategy)
                if not args.find_unused_parameters:
                    model = dygraph.parallel.DataParallel(
                        model, strategy, find_unused_parameters=False)
                else:
                    model = dygraph.parallel.DataParallel(
                        model, strategy, find_unused_parameters=True)
                print_to_err(type(self).__name__, "model built in dygraph")

            elif args.update_method == "gloo":
                paddle.distributed.init_parallel_env()
                if not args.find_unused_parameters:
                    model = dygraph.parallel.DataParallel(
                        model, find_unused_parameters=False)
                else:
                    model = dygraph.parallel.DataParallel(
                        model, find_unused_parameters=True)

            out_losses = []
            print_to_err(type(self).__name__, "begin to run dygraph training")
            for step_id, data in enumerate(train_reader()):
                data = self._get_data(data, args)
                if step_id == RUN_STEP:
                    break
                loss = self.run_one_loop(model, opt, data)
                if step_id % 10 == 0:
                    print_to_err(
                        type(self).__name__,
                        "loss at step %d: %f" % (step_id, loss.numpy()))
                out_losses.append(loss.numpy())

                loss.backward()

                opt.minimize(loss)
                if not args.accumulate_gradient:
                    model.clear_gradients()
        print_to_out(out_losses)

    def run_trainer_with_spawn(self, args):
        # 1. enable dygraph
        paddle.disable_static()

        # 2. init seed
        seed = 90
        paddle.static.default_startup_program().random_seed = seed
        paddle.static.default_main_program().random_seed = seed
        np.random.seed(seed)
        random.seed(seed)
        # get trainer id
        paddle.distributed.parallel._get_global_parallel_env()
        args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

        # 3. init parallel env
        if args.update_method in ["nccl2", "gloo"]:
            paddle.distributed.init_parallel_env()

        # 4. train model
        model, train_reader, opt = self.get_model()
        if args.update_method in ["nccl2", "gloo"]:
            model = paddle.DataParallel(
                model, find_unused_parameters=args.find_unused_parameters)

        out_losses = []
        for step_id, data in enumerate(train_reader()):
            data = self._get_data(data, args)
            if step_id == RUN_STEP:
                break
            loss = self.run_one_loop(model, opt, data)
            out_losses.append(loss.numpy())

            loss.backward()

            opt.minimize(loss)
            model.clear_gradients()
        return out_losses

    def run_use_fleet_api_trainer(self, args):
        import paddle.distributed.fleet as fleet
        # 1. enable dygraph
        paddle.disable_static()

        # 2. init seed
        seed = 90
        paddle.static.default_startup_program().random_seed = seed
        paddle.static.default_main_program().random_seed = seed
        np.random.seed(seed)
        random.seed(seed)
        # get trainer id
        paddle.distributed.parallel._get_global_parallel_env()
        args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

        # set strategy
        strategy = fleet.DistributedStrategy()
        if args.find_unused_parameters:
            strategy.find_unused_parameters = True

        # 3. init parallel env
        if args.update_method in ["nccl2", "bkcl", "hccl"]:
            fleet.init(is_collective=True, strategy=strategy)

        # 4. train model
        model, train_reader, opt = self.get_model()
        if args.update_method in ["nccl2", "bkcl", "hccl"]:
            opt = fleet.distributed_optimizer(opt)
            model = fleet.distributed_model(model)

        out_losses = []
        for step_id, data in enumerate(train_reader()):
            data = self._get_data(data, args)
            if step_id == RUN_STEP:
                break
            loss = self.run_one_loop(model, opt, data)
            out_losses.append(loss.numpy())

            loss.backward()

            opt.step()
            if not args.accumulate_gradient:
                opt.clear_grad()
        print_to_out(out_losses)


def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument('--role',
                        type=str,
                        required=True,
                        choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument('--update_method',
                        type=str,
                        default="local",
                        choices=[
                            "pserver", "nccl2", "bkcl", "local",
                            "nccl2_reduce_layer", "gloo", "hccl", "cncl"
                        ])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--nccl_comm_num', type=int, required=False, default=1)
    parser.add_argument('--enable_backward_deps', action='store_true')
    parser.add_argument('--use_hallreduce', action='store_true')
    parser.add_argument('--use_pipeline', action='store_true')
    parser.add_argument('--use_fleet_api', action='store_true')
    parser.add_argument('--use_fleet_api_20', action='store_true')
    parser.add_argument('--use_local_sgd', action='store_true')
    parser.add_argument('--diff_batch', action='store_true')
    parser.add_argument('--ut4grad_allreduce', action='store_true')
    parser.add_argument('--hallreduce_inter_nranks',
                        type=int,
                        required=False,
                        default=2)
    parser.add_argument('--current_endpoint',
                        type=str,
                        required=False,
                        default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_cpu', action='store_true')
    parser.add_argument('--use_xpu', action='store_true')
    parser.add_argument('--use_dgc', action='store_true')
    parser.add_argument('--use_npu', action='store_true')
    parser.add_argument('--use_mlu', action='store_true')
    parser.add_argument('--accumulate_gradient', action='store_true')
    parser.add_argument('--find_unused_parameters', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument('--hogwild', action='store_true')
    parser.add_argument('--save_model', action='store_true')
    parser.add_argument('--use_reader_alloc',
                        action='store_true',
                        required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument('--batch_merge_repeat',
                        required=False,
                        type=int,
                        default=1)
    parser.add_argument('--nccl2_reduce_layer_local_run',
                        required=False,
                        type=bool,
                        default=False)
    parser.add_argument('--sync_batch_norm', action='store_true')
    parser.add_argument('--fuse_all_reduce',
                        required=False,
                        type=ast.literal_eval,
                        default=None)

    args = parser.parse_args()

    if args.update_method == 'gloo':
        paddle.set_device("cpu")

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    elif args.use_fleet_api:
        model.run_use_fleet_api_trainer(args)
    elif args.use_fleet_api_20:
        model.run_use_fleet_api_20_trainer(args)
    elif args.use_pipeline:
        model.run_pipeline_trainer(args)
    else:
        model.run_trainer(args)
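

# This module is not run directly; concrete test scripts subclass one of the
# runner bases above and hand the class to runtime_main(). A minimal sketch
# (the names SimpleDistTrainer / test_dist_simple.py are hypothetical):
#
#     from test_dist_base import TestDistRunnerBase, runtime_main
#
#     class SimpleDistTrainer(TestDistRunnerBase):
#         def get_model(self, batch_size=DEFAULT_BATCH_SIZE, **kwargs):
#             ...  # build the network, return the tuple run_trainer() expects
#
#     if __name__ == "__main__":
#         runtime_main(SimpleDistTrainer)
#
# TestDistBase below launches such a script as trainer/pserver subprocesses,
# passing the command-line flags parsed above.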


import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):

    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
            self.__use_xpu = False
            self._use_dgc = False
            self.__use_npu = False
            self._use_mlu = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
            self.__use_xpu = False
            self.__use_npu = False
            self._use_mlu = False
        elif self._enforce_place == "XPU":
            self.__use_cuda = False
            self.__use_xpu = True
            self._use_dgc = False
            self.__use_npu = False
            self._use_mlu = False
        elif self._enforce_place == "NPU":
            self.__use_cuda = False
            self.__use_xpu = False
            self._use_dgc = False
            self.__use_npu = True
            self._use_mlu = False
        elif self._enforce_place == "MLU":
            self.__use_cuda = False
            self.__use_xpu = False
            self._use_dgc = False
            self.__use_npu = False
            self._use_mlu = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False
                self._use_dgc = False

        if self._use_reduce:
            assert not self._use_dgc

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._python_interp = sys.executable
        self._sync_mode = True
        self._hogwild_mode = False
        self._enforce_place = None
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._bkcl_mode = False
        self._gloo_mode = False  # now, support gloo backend
        self._hccl_mode = False
        self._cncl_mode = False
        self._pipeline_mode = False
        self._mp_mode = False
        self._diff_batch = False
        # FIXME(typhoonzero): I added this stupid argument to enable
        # testing allreduce layers, which lets users call layers.allreduce
        # to accumulate tensors anywhere. Find a better way to do this
        # test and reduce checking this argument everywhere.
        self._nccl2_reduce_layer = False
        self._lr = 0.001
        self._use_dgc = False
        self._dygraph = False
        self._nccl_comm_num = 1
        self._enable_backward_deps = False
        self._use_fleet_api = False
        self._use_fleet_api_20 = False
        self._use_local_sgd = False
        self._ut4grad_allreduce = False
        self._use_hallreduce = False
        self._save_model = False
        self._fuse_all_reduce = None
        self._accumulate_gradient = False
        self._find_unused_parameters = False
        self._setup_config()

        global DIST_UT_PORT
        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))

        if DIST_UT_PORT == 0:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                self._find_free_port(), self._find_free_port())
        else:
            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
                DIST_UT_PORT, DIST_UT_PORT + 1)
            DIST_UT_PORT += 2
            self._dist_port = DIST_UT_PORT

        self._after_setup_config()

        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def _find_free_port(self):

        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                print_to_err(
                    type(self).__name__, "socket name: %s" % s.getsockname()[1])
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self,
                      model_file,
                      check_error_log,
                      required_envs,
                      log_name=""):
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s"

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            required_envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            ps_cmd += " -m coverage run --branch -p"

        ps_cmd += " %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"

        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"

        print(ps0_cmd)
        print(ps1_cmd)
        path0 = os.path.join(self.temp_dir.name, log_name + "_ps0_err.log")
        path1 = os.path.join(self.temp_dir.name, log_name + "_ps1_err.log")
        ps0_pipe = open(path0, "wb")
        ps1_pipe = open(path1, "wb")

        print_to_err(type(self).__name__, "going to start pserver process 0")
        ps0_proc = subprocess.Popen(ps0_cmd.strip().split(" "),
                                    stdout=subprocess.PIPE,
                                    stderr=ps0_pipe,
                                    env=required_envs)
        print_to_err(type(self).__name__, "going to start pserver process 1")
        ps1_proc = subprocess.Popen(ps1_cmd.strip().split(" "),
                                    stdout=subprocess.PIPE,
                                    stderr=ps1_pipe,
                                    env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1,
                   log_name="",
                   devices="1"):

        cmd = self._python_interp

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            cmd += " -m coverage run --branch -p"

        cmd += " %s --role trainer --update_method local --lr %f" % (model,
                                                                     self._lr)

        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat
        if self._nccl2_reduce_layer:
            cmd += " --nccl2_reduce_layer_local_run 1"

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {
                "CUDA_VISIBLE_DEVICES": devices,
                "PADDLE_TRAINERS_NUM": "1",
                "PADDLE_TRAINER_ID": "0"
            }
        elif self.__use_xpu:
            cmd += " --use_xpu"
            env_local = {
                "FLAGS_selected_xpus": devices,
                "PADDLE_TRAINERS_NUM": "1",
                "PADDLE_TRAINER_ID": "0"
            }
        elif self.__use_npu:
            cmd += " --use_npu"
            env_local = {
                "FLAGS_selected_npus": devices,
                "PADDLE_TRAINERS_NUM": "1",
                "PADDLE_TRAINER_ID": "0"
            }
        else:
            env_local = {'CPU_NUM': '1'}

        # do not use dgc on a single card
        if len(devices) > 1 and self._use_dgc:
            cmd += " --use_dgc"

        if self._accumulate_gradient:
            cmd += " --accumulate_gradient"

        if self._find_unused_parameters:
            cmd += " --find_unused_parameters"

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            path = os.path.join(self.temp_dir.name, log_name + "_local.log")
            err_log = open(path, "wb")
            local_proc = subprocess.Popen(cmd.split(" "),
                                          stdout=subprocess.PIPE,
                                          stderr=err_log,
                                          env=env_local)
        else:
            local_proc = subprocess.Popen(cmd.split(" "),
                                          stdout=subprocess.PIPE,
                                          stderr=subprocess.PIPE,
                                          env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_local_gloo(self,
                        model,
                        envs,
                        check_error_log=False,
                        batch_size=DEFAULT_BATCH_SIZE,
                        batch_merge_repeat=1,
                        log_name="",
                        devices="0"):
        saved_endpoints = self._ps_endpoints
        self._ps_endpoints = self._ps_endpoints.split(',')[0]
        result = self._run_cluster_gloo(model, envs, 'gloo', check_error_log,
                                        log_name)
        self._ps_endpoints = saved_endpoints
        return result

    def _run_cluster(self, model, envs, check_error_log, log_name):
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log,
                                                          envs,
                                                          log_name=log_name)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s"

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            tr_cmd += " -m coverage run --branch -p"

        tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"

        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._hogwild_mode:
            tr0_cmd += " --hogwild"
            tr1_cmd += " --hogwild"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))

        path0 = os.path.join(self.temp_dir.name, log_name + "_tr0_err.log")
        path1 = os.path.join(self.temp_dir.name, log_name + "_tr1_err.log")
        tr0_pipe = open(path0, "wb")
        tr1_pipe = open(path1, "wb")

        print_to_err(type(self).__name__, "going to start trainer process 0")
        tr0_proc = subprocess.Popen(tr0_cmd.strip().split(" "),
                                    stdout=subprocess.PIPE,
                                    stderr=tr0_pipe,
                                    env=env0)
        print_to_err(type(self).__name__, "going to start trainer process 1")
        tr1_proc = subprocess.Popen(tr1_cmd.strip().split(" "),
                                    stdout=subprocess.PIPE,
                                    stderr=tr1_pipe,
                                    env=env1)

        # Wait until the trainer processes terminate.
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer log files
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _get_gloo_trainer_cmd(self, model, ep, update_method, trainer_id,
                              trainer_num):
        env = {}
        tr_cmd = "%s -u"

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            tr_cmd += " -m coverage run --branch -p"

        tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f"

        tr_cmd = tr_cmd % \
                 (self._python_interp, model, self._ps_endpoints,
                  trainer_id, ep, update_method, self._lr)

        if self._use_reduce:
            tr_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr_cmd += " --use_reader_alloc"
        #assert self._use_reduce == False, "gloo not support _use_reduce"
        #assert self._use_reader_alloc == False, "gloo not support _use_reader_alloc"
        if self._save_model:
            tr_cmd += " --save_model"
        if self._diff_batch:
            tr_cmd += " --diff_batch"
        self.__use_cuda = False
        self.__use_xpu = False
        assert self.__use_cuda == False, "gloo not support use cuda"
        assert self.__use_xpu == False, "gloo not support use xpu"
        tr_cmd += " --use_cpu"
        env.update({
            "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
            "PADDLE_TRAINER_ID": "{}".format(trainer_id),
            "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
            "PADDLE_CURRENT_ENDPOINT": ep,
            "PADDLE_DISTRI_BACKEND": "gloo",
            "GLOG_v": "2",
        })

        assert self._use_dgc == False, "gloo not support use dgc"

        if self._accumulate_gradient:
            tr_cmd += " --accumulate_gradient"

        if self._find_unused_parameters:
            tr_cmd += " --find_unused_parameters"

        assert self._pipeline_mode == False, "gloo not support use pipeline"

        if self._enable_backward_deps:  # build strategy, save it
            tr_cmd += " --enable_backward_deps"

        if self._fuse_all_reduce is not None:
            tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce)

        assert self._use_fleet_api == False, "gloo not support use fleet api"
        assert self._use_fleet_api_20 == False, "gloo not support use fleet api"
        return tr_cmd, env

1203 1204 1205
    def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id,
                               trainer_num):
        env = {}
1206 1207 1208 1209 1210 1211 1212
        tr_cmd = "%s -u"

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            tr_cmd += " -m coverage run --branch -p"

        tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f"

1213
        tr_cmd = tr_cmd % \
T
tangwei12 已提交
1214 1215
                 (self._python_interp, model, self._ps_endpoints,
                  trainer_id, ep, update_method, self._lr)
W
Wu Yi 已提交
1216 1217

        if self._use_reduce:
1218
            tr_cmd += " --use_reduce"
W
Wu Yi 已提交
1219
        if self._use_reader_alloc:
1220
            tr_cmd += " --use_reader_alloc"
1221 1222
        if self._save_model:
            tr_cmd += " --save_model"
W
Wu Yi 已提交
1223
        if self.__use_cuda:
1224 1225
            tr_cmd += " --use_cuda"
            env.update({
1226
                "FLAGS_selected_gpus": "{}".format(0),
W
WangXi 已提交
1227
                "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id),
1228
                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
1229 1230 1231
                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": ep,
1232
            })
1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245
        # TODO(liuyuhui):XPU_VISIBLE_DEVICES is not working right now,
        # will update it after Badiu Kunlun partners' support.
        elif self.__use_xpu:
            tr_cmd += " --use_xpu"
            env.update({
                "FLAGS_selected_xpus": "{}".format(trainer_id),
                #"XPU_VISIBLE_DEVICES": "{}".format(trainer_id + 1),
                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": ep,
                "GLOG_v": "2",
            })
        elif self.__use_npu:
            tr_cmd += " --use_npu"
            env.update({
                "FLAGS_selected_npus": "{}".format(trainer_id),
                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": ep,
                "GLOG_v": "2",
            })
        elif self._use_mlu:
            tr_cmd += " --use_mlu"
            env.update({
                "FLAGS_selected_mlus": "{}".format(trainer_id),
                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": ep,
                "GLOG_v": "4",
            })
        else:
            env.update({'CPU_NUM': '1'})

        if self._use_dgc:
            tr_cmd += " --use_dgc"

        if self._accumulate_gradient:
            tr_cmd += " --accumulate_gradient"

        if self._find_unused_parameters:
            tr_cmd += " --find_unused_parameters"

        if self._pipeline_mode:
            tr_cmd += " --use_pipeline"
        if self._mp_mode:
            env = {"FLAGS_selected_gpus": "{}".format(trainer_id)}

        if self._nccl_comm_num > 1:
            tr_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num)

        if self._use_hallreduce:
            tr_cmd += " --use_hallreduce --hallreduce_inter_nranks 2"

        if self._enable_backward_deps:
            tr_cmd += " --enable_backward_deps"

        if self._fuse_all_reduce is not None:
            tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce)

        if self._use_fleet_api:
            tr_cmd += " --use_fleet_api_20" if self._use_fleet_api_20 else " --use_fleet_api"
            if self._use_local_sgd:
                tr_cmd += " --use_local_sgd"
            if self._ut4grad_allreduce:
                tr_cmd += " --ut4grad_allreduce"
            if hasattr(self, '_sync_batch_norm') and self._sync_batch_norm:
                tr_cmd += " --sync_batch_norm"

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')

        return tr_cmd, env

    def _run_cluster_gloo(self, model, envs, update_method, check_error_log,
                          log_name):
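        """Spawn one gloo (CPU) trainer process per endpoint and return the
        losses each trainer pickled to stdout (a single result when only one
        trainer is configured).
        """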
        assert update_method == "gloo", "_run_cluster_gloo must have update_method: gloo, but get %s" % update_method
        assert not self._use_hallreduce, "_run_cluster_gloo must have _use_hallreduce = false"

        worker_endpoints = self._ps_endpoints.split(",")

        trainer_num = len(worker_endpoints)

        procs = []
        pipes = []
        for i in range(0, trainer_num):
            tr_cmd, tr_env = self._get_gloo_trainer_cmd(model,
                                                        worker_endpoints[i],
                                                        update_method, i,
                                                        trainer_num)
            tr_env.update(envs)
            tr_env["GLOG_vmodule"] = 'gloo_context=4'
            tr_env["GLOG_v"] = '3'
            print("use_hallreduce:{} tr_cmd:{}, env: {}".format(
                self._use_hallreduce, tr_cmd, tr_env))

            path = os.path.join(self.temp_dir.name,
                                log_name + "_tr{}_err.log".format(i))
            tr_pipe = open(path, "wb")

            print_to_err(
                type(self).__name__,
                "going to start process {} with gloo".format(i))
            tr_proc = subprocess.Popen(tr_cmd.strip().split(" "),
                                       stdout=subprocess.PIPE,
                                       stderr=tr_pipe,
                                       env=tr_env)

            procs.append(tr_proc)
            pipes.append(tr_pipe)

        outs = []
        for i in range(0, trainer_num):
            tr_out, tr_err = procs[i].communicate()
            outs.append(tr_out)
            pipes[i].close()
            sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err))

        if trainer_num == 1:
            if check_error_log:
                print("outs[0]:", outs[0])
            return pickle.loads(outs[0])

        else:
            if check_error_log:
                print("outs[0]:", outs[0])
                print("outs[1]:", outs[1])
            return pickle.loads(outs[0]), pickle.loads(outs[1])

    def _run_cluster_nccl2(self, model, envs, update_method, check_error_log,
                           log_name):
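        """Spawn the collective trainers described by ``self._ps_endpoints``
        (four endpoints when hierarchical allreduce is enabled) and return the
        pickled losses of trainer 0 and trainer 1.
        """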
        if self._use_hallreduce:
            self._ps_endpoints = ""

            global DIST_UT_PORT
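            # Hierarchical allreduce needs four trainers: reuse the shared
            # DIST_UT_PORT block when it is preset, otherwise pick free ports.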
            if DIST_UT_PORT == 0:
                # NOTE(wangxi): hallreduce test must use 4 cards when nccl >= 2.7
                for i in range(0, 4):
                    self._ps_endpoints += "127.0.0.1:%s," % (
                        self._find_free_port())
            else:
                for i in range(0, 4):
                    self._ps_endpoints += "127.0.0.1:%s," % (DIST_UT_PORT + i)
                DIST_UT_PORT += 4
            self._ps_endpoints = self._ps_endpoints[:-1]

        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")

        trainer_num = len(worker_endpoints)

        procs = []
        pipes = []
        for i in range(0, trainer_num):
            tr_cmd, tr_env = self._get_nccl2_trainer_cmd(
                model, worker_endpoints[i], update_method, i, trainer_num)
            tr_env.update(envs)
            print("use_hallreduce:{} tr_cmd:{}, env: {}".format(
                self._use_hallreduce, tr_cmd, tr_env))

            path = os.path.join(self.temp_dir.name,
                                log_name + "_tr{}_err.log".format(i))
            tr_pipe = open(path, "wb")

            print_to_err(
                type(self).__name__,
                "going to start process {} with nccl2".format(i))
            tr_proc = subprocess.Popen(tr_cmd.strip().split(" "),
                                       stdout=subprocess.PIPE,
                                       stderr=tr_pipe,
                                       env=tr_env)

            procs.append(tr_proc)
            pipes.append(tr_pipe)

        outs = []
        for i in range(0, trainer_num):
            tr_out, tr_err = procs[i].communicate()
            outs.append(tr_out)
            pipes[i].close()
            sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err))

        if check_error_log:
            print("outs[0]:", outs[0])
            print("outs[1]:", outs[1])

        return pickle.loads(outs[0]), pickle.loads(outs[1])

    def _run_pipeline(self, model, envs, check_error_log, log_name):
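        """Run the pipeline-parallel case with the nccl2 update method, one
        trainer per endpoint, each pinned to its own GPU; returns the pickled
        losses of trainer 0 and trainer 1.
        """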
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        update_method = "nccl2"

        trainer_num = len(worker_endpoints)

        procs = []
        pipes = []
        for i in range(0, trainer_num):
            tr_cmd, tr_env = self._get_nccl2_trainer_cmd(
                model, worker_endpoints[i], update_method, i, trainer_num)
            tr_env.update(envs)
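            # Expose both GPUs to each trainer, select one rank per process via
            # FLAGS_selected_gpus, and disable NCCL shared-memory transport.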
            tr_env['CUDA_VISIBLE_DEVICES'] = "0,1"
            tr_env['NCCL_SHM_DISABLE'] = '1'
            tr_env['FLAGS_selected_gpus'] = str(i)
            tr_env['FLAGS_cudnn_deterministic'] = '0'
            print("tr_cmd:{}, env: {}".format(tr_cmd, tr_env))

            path = os.path.join(self.temp_dir.name,
                                log_name + "_tr{}_err.log".format(i))
            tr_pipe = open(path, "wb")

            print_to_err(
                type(self).__name__,
                "going to start process {} with nccl2".format(i))
            tr_proc = subprocess.Popen(tr_cmd.strip().split(" "),
                                       stdout=subprocess.PIPE,
                                       stderr=tr_pipe,
                                       env=tr_env)

            procs.append(tr_proc)
            pipes.append(tr_pipe)

        outs = []
        for i in range(0, trainer_num):
            tr_out, tr_err = procs[i].communicate()
            outs.append(tr_out)
            pipes[i].close()
            sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err))

        if check_error_log:
            print("outs[0]:", outs[0])
            print("outs[1]:", outs[1])
        return pickle.loads(outs[0]), pickle.loads(outs[1])

    def _get_required_envs(self, check_error_log=False, need_envs={}):
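        """Assemble the baseline environment for spawned trainers: PATH and
        library paths, deterministic cuDNN, RPC/NCCL flags and, when
        ``check_error_log`` is set, verbose GLOG settings for the
        communication stack.
        """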
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "30000",  # 30s to fail fast
            "FLAGS_rpc_retry_bind_port": "50",
            "FLAGS_cudnn_deterministic": "1",
            "FLAGS_rpc_disable_reuse_port": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1",
            "NCCL_SHM_DISABLE": "1",
            "FLAGS_CONVERT_GRAPH_TO_PROGRAM": "1"
        }

        if check_error_log:
            required_envs["GLOG_vmodule"] = \
                "fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
                "alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
                "sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
                "grpc_server=10,request_handler_impl=10,section_worker=10"
            required_envs["GLOG_logtostderr"] = "1"

        if os.getenv('NVIDIA_TF32_OVERRIDE') is not None:
            required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv(
                'NVIDIA_TF32_OVERRIDE', '')

        required_envs.update(need_envs)
        return required_envs

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={},
                         log_name=""):
        if self._dygraph and (self._gloo_mode or self._nccl2_mode):
            need_envs.update({"FLAGS_enable_eager_mode": "1"})
            with _test_eager_guard():
                self.check_with_place_func(model_file=model_file,
                                           delta=delta,
                                           check_error_log=check_error_log,
                                           need_envs=need_envs,
                                           log_name=log_name)
            need_envs.update({"FLAGS_enable_eager_mode": "0"})
            self.check_with_place_func(model_file=model_file,
                                       delta=delta,
                                       check_error_log=check_error_log,
                                       need_envs=need_envs,
                                       log_name=log_name)
        else:
            self.check_with_place_func(model_file=model_file,
                                       delta=delta,
                                       check_error_log=check_error_log,
                                       need_envs=need_envs,
                                       log_name=log_name)

    def check_with_place_func(self,
                              model_file,
                              delta=1e-3,
                              check_error_log=False,
                              need_envs={},
                              log_name=""):
        required_envs = self._get_required_envs(check_error_log, need_envs)

        if self._gloo_mode:
            local_losses = self._run_local_gloo(model_file,
                                                required_envs,
                                                check_error_log,
                                                log_name=log_name)
        else:
            local_losses = self._run_local(model_file,
                                           required_envs,
                                           check_error_log,
                                           log_name=log_name)

        if self._nccl2_mode:
            if self._nccl2_reduce_layer:
                tr0_losses, tr1_losses = self._run_cluster_nccl2(
                    model_file,
                    required_envs,
                    update_method="nccl2_reduce_layer",
                    check_error_log=check_error_log,
                    log_name=log_name)
            else:
                tr0_losses, tr1_losses = self._run_cluster_nccl2(
                    model_file,
                    required_envs,
                    update_method='nccl2',
                    check_error_log=check_error_log,
                    log_name=log_name)
        elif self._bkcl_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file,
                required_envs,
                update_method='bkcl',
                check_error_log=check_error_log,
                log_name=log_name)
        elif self._gloo_mode:
            # gloo mode: CPU-only parallel training (@xiongkun03)
            tr0_losses, tr1_losses = self._run_cluster_gloo(
                model_file,
                required_envs,
                update_method='gloo',
                check_error_log=check_error_log,
                log_name=log_name)
        elif self._hccl_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file,
                required_envs,
                update_method='hccl',
                check_error_log=check_error_log,
                log_name=log_name)
        elif self._cncl_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file,
                required_envs,
                update_method='cncl',
                check_error_log=check_error_log,
                log_name=log_name)
        elif self._pipeline_mode:
            tr0_losses, tr1_losses = self._run_pipeline(model_file,
                                                        required_envs,
                                                        check_error_log,
                                                        log_name=log_name)
        else:
            tr0_losses, tr1_losses = self._run_cluster(model_file,
                                                       required_envs,
                                                       check_error_log,
                                                       log_name=log_name)

        for step_id in range(RUN_STEP):
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            if self._pipeline_mode:
                dist_loss = np.array([tr1_loss])
            else:
                dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)

    def check_with_place_multi_cards(self,
                                     model_file,
                                     delta=1e-3,
                                     check_error_log=False,
                                     need_envs={},
                                     log_name=""):

        # P2P or SHM must be enabled, otherwise multi-card mode will hang
        need_envs.update({"NCCL_P2P_DISABLE": "0", "NCCL_SHM_DISABLE": "0"})

        required_envs = self._get_required_envs(check_error_log, need_envs)

        if self._use_dgc:
            multi_cards_losses = self._run_local(model_file,
                                                 required_envs,
                                                 check_error_log,
                                                 log_name=log_name +
                                                 "_dgc_2cards",
                                                 devices="0,1")

            self._use_dgc = False
            base_losses = self._run_local(model_file,
                                          required_envs,
                                          check_error_log,
                                          log_name=log_name + "_base_2cards",
                                          devices="0,1")

            self._use_dgc = True

            for step_id in range(RUN_STEP):
                base_loss = base_losses[step_id]
                multi_cards_loss = multi_cards_losses[step_id]
                print("=======", base_loss, ":", multi_cards_loss, "=======")
                self.assertAlmostEqual(base_loss, multi_cards_loss, delta=delta)