#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
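"""Shared base classes for PaddlePaddle distributed-training unit tests.

A test model file subclasses TestDistRunnerBase and calls runtime_main() so it
can be launched as a trainer or pserver subprocess; the corresponding unittest
subclasses TestDistBase and calls check_with_place() to compare local and
distributed per-step losses.

Illustrative usage (class and file names below are hypothetical):

    # dist_model.py
    class DistModel(TestDistRunnerBase):
        def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
            ...  # build the network and return the tuple run_trainer expects

    if __name__ == "__main__":
        runtime_main(DistModel)

    # test_dist_model.py
    class TestDistModel(TestDistBase):
        def _setup_config(self):
            self._sync_mode = True

        def test_train(self):
            self.check_with_place("dist_model.py", delta=1e-3)
"""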

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import compiler

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


class TestDistRunnerBase(object):
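    """Base runner executed inside the trainer/pserver subprocesses.

    Subclasses implement get_model(); runtime_main() then dispatches to
    run_pserver() or run_trainer() according to the command-line arguments.
    """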
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False,
                       current_endpoint=None):
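        """Build a DistributeTranspiler that rewrites the single-process
        program into its distributed (pserver) form."""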
        # NOTE: do not import fluid until runtime, or else forking processes will cause errors.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode,
            current_endpoint=current_endpoint)
        return t

    def run_pserver(self, args):
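        """Transpile for the pserver role and run the parameter server
        program; the call blocks while the server is listening."""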
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
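        """Build the model, transpile it according to --update_method, run
        RUN_STEP training steps, and write the pickled per-step losses to
        stdout for the parent test process to collect."""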
        self.lr = args.lr
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        if args.batch_merge_repeat > 1:
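            # Batch merging: insert the multi_batch_merge_pass so several
            # mini-batches are processed per parameter update.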
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(
                len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
            mypass.set("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            build_stra.num_trainers = len(args.endpoints.split(","))
            build_stra.trainer_id = args.trainer_id
        else:
            build_stra.num_trainers = 1
            build_stra.trainer_id = 0

        binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_stra,
            exec_strategy=strategy)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
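            # With --use_reader_alloc, each trainer keeps only every other
            # sample, so the two trainers consume disjoint halves of the batch.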
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(binary,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


def runtime_main(test_class):
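    """Command-line entry point used when a test model file is launched as a
    subprocess by TestDistBase."""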
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
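    """Base TestCase that launches trainer/pserver subprocesses and checks
    that distributed training matches local training step by step."""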
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._lr = 0.001
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
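        """Return a free TCP port that has not been handed out before."""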
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self, model_file, check_error_log, required_envs):
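        """Launch the two parameter-server subprocesses and return their
        process handles together with their stderr log files."""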
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):
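        """Train the model in a single local process and return its per-step
        losses (unpickled from the subprocess stdout)."""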

        cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
                                                self._lr)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
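        """Run pserver-mode distributed training with two trainers and two
        pservers; return the per-step losses of both trainers."""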
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until both trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer and pserver log files
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print server log
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print log
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
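        """Run NCCL2-mode distributed training with two trainer subprocesses;
        return the per-step losses of both trainers."""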
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep, self._lr)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer log files
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
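        """Run local and distributed training on model_file and assert that
        the per-step losses agree within delta."""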
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "FLAGS_cudnn_deterministic": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1"
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        local_losses = self._run_local(model_file, required_envs,
                                       check_error_log)
        if self._nccl2_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)

        for step_id in range(RUN_STEP):
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)