#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid
from paddle.fluid import compiler

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


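# Base runner that runtime_main() below drives as a separate subprocess; the
# per-model dist tests subclass it and implement get_model().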
class TestDistRunnerBase(object):
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False,
                       current_endpoint=None):
        # NOTE: delay importing fluid until runtime, or else forking processes will cause an error.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        config.runtime_split_send_recv = True
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode,
            current_endpoint=current_endpoint)
        return t

    def run_pserver(self, args):
        self.lr = args.lr
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
        self.lr = args.lr
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
            with open("/tmp/trainer." + str(args.trainer_id) + ".proto",
                      "w") as f:
                f.write(str(trainer_prog))
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()
        build_stra.debug_graphviz_path = "/tmp/graph-" + str(args.trainer_id)

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

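        # If batch merging is requested, insert the multi_batch_merge_pass at
        # the front of the pass pipeline and set its num_repeats attribute, so
        # that several repeated batches are effectively merged into one update.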
        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
            mypass.set("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            build_stra.num_trainers = len(args.endpoints.split(","))
            build_stra.trainer_id = args.trainer_id
        else:
            build_stra.num_trainers = 1
            build_stra.trainer_id = 0

        binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
            loss_name=avg_cost.name,
            build_strategy=build_stra,
            exec_strategy=strategy)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

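        # In distributed (non-local) runs with --use_reader_alloc, each trainer
        # keeps only the samples whose offset matches its trainer_id modulo 2,
        # so the two trainers split every batch between them.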
        def get_data():
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(binary,
                            fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
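        # Dump the pickled per-step losses to stdout so the parent test process
        # can read them back with pickle.loads().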
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


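# Subprocess entry point: parse the command-line flags and run the model as
# either a parameter server or a trainer.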
def runtime_main(test_class):
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument('--lr', required=False, type=float, default=0.001)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._lr = 0.001
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

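    # Starts the two parameter-server subprocesses; their stderr is redirected
    # to /tmp/ps0_err.log and /tmp/ps1_err.log.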
    def start_pserver(self, model_file, check_error_log, required_envs):
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):

        cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
                                                self._lr)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers, self._lr)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until both trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close the trainer and pserver log files
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print server log
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print log
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

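        # nccl2 mode launches only the two trainer processes; no parameter
        # servers are started.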
        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep, self._lr)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep, self._lr)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
518
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
519
            "FLAGS_cudnn_deterministic": "1",
W
            "NCCL_P2P_DISABLE": "1"
522 523 524 525 526
        }

        required_envs.update(need_envs)

        if check_error_log:
W
528 529 530 531 532
            required_envs["GLOG_logtostderr"] = "1"

        local_losses\
            = self._run_local(model_file, required_envs,
                                       check_error_log)
W
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)
539 540

        for step_id in range(RUN_STEP):
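        # The distributed loss, averaged over the two trainers, should match
        # the local loss within `delta` at every step.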
W
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)