#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


class TestDistRunnerBase(object):
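    """Base runner executed inside the spawned pserver/trainer subprocesses.

    Subclasses provide the model via get_model(); runtime_main() instantiates
    the subclass and drives it according to the parsed command-line flags.
    """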
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
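        """Build the model; concrete test runners must override this.

        run_trainer() expects it to return (test_program, avg_cost,
        train_reader, test_reader, batch_acc, predict).
        """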
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False):
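        """Transpile main_program in place for pserver-mode distributed training."""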
        # NOTE: import fluid only at runtime, or else forking processes will cause an error.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode)
        return t

    def run_pserver(self, args):
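        """Build the pserver program via the transpiler and run it (blocking)."""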
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
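        """Train for RUN_STEP steps and write the pickled losses to stdout."""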
        self.lr = args.lr
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        startup_exe = fluid.Executor(place)
        startup_exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(
                len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
            mypass.set_int("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            num_trainers = len(args.endpoints.split(","))
            trainer_id = args.trainer_id
        else:
            num_trainers = 1
            trainer_id = 0

        exe = fluid.ParallelExecutor(
            args.use_cuda,
            loss_name=avg_cost.name,
            exec_strategy=strategy,
            build_strategy=build_stra,
            num_trainers=num_trainers,
            trainer_id=trainer_id)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


def runtime_main(test_class):
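    """Entry point of the spawned subprocesses.

    Parses the command-line flags emitted by TestDistBase and runs test_class
    either as a parameter server or as a trainer.
    """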
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
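    """Base test case that launches a local run and a distributed run and
    compares their losses.

    A concrete test usually overrides _setup_config() to pick the mode
    (sync/async, reduce/all-reduce, pserver/nccl2, ...) and then calls
    check_with_place() with its model file. A minimal sketch (the class and
    model file names here are hypothetical):

        class TestDistExample(TestDistBase):
            def _setup_config(self):
                self._sync_mode = True

            def test_dist_train(self):
                self.check_with_place("dist_example.py", delta=1e-5)
    """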
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._lr = 0.001
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
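        """Return a free TCP port that has not been handed out in this test."""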
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self, model_file, check_error_log, required_envs):
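        """Launch both pserver subprocesses, redirecting stderr to /tmp logs."""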
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):
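        """Train in a single local subprocess and return its per-step losses."""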

        cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
                                                self._lr)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
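        """Run 2 pservers and 2 trainers with the pserver update method and
        return both trainers' per-step losses."""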
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until both trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer and pserver log files
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print pserver stderr logs
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print trainer stdout (on success) and stderr logs
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
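        """Run two NCCL2-mode trainers (no pservers) and return their losses."""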
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep, self._lr)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
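        """Compare local training losses against distributed training losses.

        For each of the RUN_STEP steps, the mean loss of the two trainers
        must match the local loss within `delta`.
        """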
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "FLAGS_cudnn_deterministic": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1"
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        local_losses = self._run_local(model_file, required_envs,
                                       check_error_log)
        if self._nccl2_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)

        for step_id in range(RUN_STEP):
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)