test_dist_base.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import compiler

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


class TestDistRunnerBase(object):
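    """Model runner executed inside each spawned subprocess.

    Subclasses implement get_model(); run_pserver() and run_trainer()
    transpile the program as needed and run it as a parameter server or
    as a trainer.
    """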
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False,
                       current_endpoint=None):
        # NOTE: import fluid only at runtime, or else forking processes will cause errors.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        config.runtime_split_send_recv = True
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode,
            current_endpoint=current_endpoint)
        return t

    def run_pserver(self, args):
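        # Build the model, transpile it for this pserver endpoint, and run the
        # blocking parameter-server program on CPU.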
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
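        # Build the model, transpile it according to --update_method
        # (pserver / nccl2 / local), train for RUN_STEP batches, and write the
        # pickled per-step losses to stdout for the parent process to collect.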
        self.lr = args.lr
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
            with open("/tmp/trainer." + str(args.trainer_id) + ".proto",
                      "w") as f:
                f.write(str(trainer_prog))
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()
        build_stra.debug_graphviz_path = "/tmp/graph-" + str(args.trainer_id)

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(
                len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
            mypass.set("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            build_stra.num_trainers = len(args.endpoints.split(","))
            build_stra.trainer_id = args.trainer_id
        else:
            build_stra.num_trainers = 1
            build_stra.trainer_id = 0

        binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_stra,
            exec_strategy=strategy)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
            origin_batch = next(reader_generator)
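            # For distributed runs, shard each batch between the two trainers:
            # keep only samples whose offset parity matches trainer_id so the
            # workers consume disjoint halves of every batch.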
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(binary,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


def runtime_main(test_class):
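    # Entry point for the spawned model script: parse the flags assembled by
    # TestDistBase (role, endpoints, update_method, ...) and dispatch to
    # run_pserver() or run_trainer().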
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
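    """Base test case that launches the model script as subprocesses.

    It runs the script locally and as a 2-pserver/2-trainer (or 2-worker
    nccl2) cluster, then compares the per-step losses in check_with_place().
    """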
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._lr = 0.001
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self, model_file, check_error_log, required_envs):
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):

        cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
                                                self._lr)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until both trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer and pserver log pipes
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print server log
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print trainer stdout and stderr logs
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep, self._lr)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
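        """Run `model_file` locally and as a distributed job, then check that
        the distributed loss (averaged over the two trainers) stays within
        `delta` of the local loss at every step."""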
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "FLAGS_cudnn_deterministic": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1"
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        local_losses = self._run_local(model_file, required_envs,
                                       check_error_log)
        if self._nccl2_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)

        for step_id in range(RUN_STEP):
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
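

# Usage sketch (illustrative only; the script name and class names below are
# hypothetical and not defined in this file). A concrete model script, say
# "dist_mnist.py", subclasses TestDistRunnerBase, implements get_model(), and
# calls runtime_main() so the same file can be launched as a pserver, a
# trainer, or a local run. The unittest side then subclasses TestDistBase:
#
#     class TestDistExample2x2(TestDistBase):
#         def _setup_config(self):
#             self._sync_mode = True
#             self._use_reduce = False
#
#         def test_dist_train(self):
#             self.check_with_place("dist_mnist.py", delta=1e-5)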