#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import compiler

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


class TestDistRunnerBase(object):
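    """Runner executed inside each launched subprocess.

    Concrete tests subclass this and implement get_model(); runtime_main()
    then dispatches to run_pserver() or run_trainer() based on the --role
    and --update_method arguments.
    """
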
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False):
        # NOTE: import fluid only at runtime, or else forking processes will cause errors.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode)
        return t

    def run_pserver(self, args):
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
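        # Build the model, transpile it according to --update_method
        # (pserver, nccl2, or local), run RUN_STEP mini-batches, and emit the
        # per-step losses to stdout as a pickle for the parent test process.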
        self.lr = args.lr
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(
                len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
            mypass.set_int("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            build_stra.num_trainers = len(args.endpoints.split(","))
            build_stra.trainer_id = args.trainer_id
        else:
            build_stra.num_trainers = 1
            build_stra.trainer_id = 0

        binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_stra,
            exec_strategy=strategy)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
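            # Fetch the next batch; when running distributed with
            # --use_reader_alloc, each trainer keeps only the samples whose
            # offset parity matches its trainer_id, so the two trainers
            # consume disjoint halves of the batch.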
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(binary,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


def runtime_main(test_class):
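    # Command-line entry point used by each spawned subprocess. A model test
    # file is expected to subclass TestDistRunnerBase and call runtime_main()
    # under __main__; a minimal sketch (hypothetical model file, names are
    # illustrative only):
    #
    #     class TestDistExampleModel(TestDistRunnerBase):
    #         def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
    #             # build the network and return: test_program, avg_cost,
    #             # train_reader, test_reader, batch_acc, predict
    #             ...
    #
    #     if __name__ == "__main__":
    #         runtime_main(TestDistExampleModel)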
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
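    """End-to-end comparison test driver.

    Launches the given model file as a local process and as a distributed
    job (pserver or nccl2 mode), then checks that the per-step losses of
    the distributed trainers match the local run within a tolerance.
    """
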
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
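        # Defaults below can be overridden by subclasses in _setup_config().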
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._lr = 0.001
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self, model_file, check_error_log, required_envs):
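        # Launch the two parameter-server subprocesses, one per endpoint;
        # their stderr goes to /tmp/ps0_err.log and /tmp/ps1_err.log.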
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):
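        # Run the model file as a single local trainer process and return
        # its per-step losses, read back from the pickle it prints to stdout.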

        cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
                                                self._lr)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
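        # Start two pservers and two trainers in pserver mode, wait for both
        # trainers to finish, and return their per-step losses.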
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until the trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close the trainer and pserver log files
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print server log
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print trainer logs
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep, self._lr)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer log files
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "FLAGS_cudnn_deterministic": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1"
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        local_losses = self._run_local(model_file, required_envs,
                                       check_error_log)
        if self._nccl2_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)

        for step_id in range(RUN_STEP):
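            # The distributed loss is the mean of the two trainers' losses
            # and must match the local loss to within `delta`.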
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)