#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import time

import unittest
import os
import sys
import signal
import subprocess
import six
import argparse
import pickle
import numpy as np

import paddle.fluid as fluid

RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2


class TestDistRunnerBase(object):
    def get_model(self, batch_size=DEFAULT_BATCH_SIZE):
        raise NotImplementedError(
            "get_model should be implemented by child classes.")

    @staticmethod
    def get_transpiler(trainer_id,
                       main_program,
                       pserver_endpoints,
                       trainers,
                       sync_mode,
                       dc_asgd=False):
        # NOTE: delay importing fluid until runtime, or else forking processes will cause an error.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            program=main_program,
            pservers=pserver_endpoints,
            trainers=trainers,
            sync_mode=sync_mode)
        return t

    def run_pserver(self, args):
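        """Build the model, transpile it into a parameter-server program and
        run that program (blocking) on a CPU executor."""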
        self.get_model(batch_size=args.batch_size)
        # NOTE: pserver should not call memory optimize
        t = self.get_transpiler(args.trainer_id,
                                fluid.default_main_program(), args.endpoints,
                                args.trainers, args.sync_mode, args.dc_asgd)
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        startup_prog = t.get_startup_program(args.current_endpoint,
                                             pserver_prog)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)
        exe.run(pserver_prog)

    def run_trainer(self, args):
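        """Build the model, transpile it according to args.update_method
        ("pserver", "nccl2" or "local"), run RUN_STEP training steps with a
        ParallelExecutor and write the pickled per-step losses to stdout so
        the launching test process can collect them."""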
        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
            self.get_model(batch_size=args.batch_size)

        if args.mem_opt:
            fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
        if args.update_method == "pserver":
            t = self.get_transpiler(args.trainer_id,
                                    fluid.default_main_program(),
                                    args.endpoints, args.trainers,
                                    args.sync_mode, args.dc_asgd)
            trainer_prog = t.get_trainer_program()
        elif args.update_method == "nccl2":
            # transpile for nccl2
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            nccl2_t = fluid.DistributeTranspiler(config=config)
            nccl2_t.transpile(
                args.trainer_id,
                program=fluid.default_main_program(),
                startup_program=fluid.default_startup_program(),
                trainers=args.endpoints,
                current_endpoint=args.current_endpoint)
            trainer_prog = fluid.default_main_program()
        else:
            trainer_prog = fluid.default_main_program()

        if args.use_cuda:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()

        startup_exe = fluid.Executor(place)
        startup_exe.run(fluid.default_startup_program())

        strategy = fluid.ExecutionStrategy()
        strategy.num_threads = 1
        strategy.allow_op_delay = False

        build_stra = fluid.BuildStrategy()

        if args.use_reduce:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        else:
            build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
            mypass = pass_builder.insert_pass(
                len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
            mypass.set_int("num_repeats", args.batch_merge_repeat)

        if args.update_method == "nccl2":
            num_trainers = len(args.endpoints.split(","))
            trainer_id = args.trainer_id
        else:
            num_trainers = 1
            trainer_id = 0

        exe = fluid.ParallelExecutor(
            args.use_cuda,
            loss_name=avg_cost.name,
            exec_strategy=strategy,
            build_strategy=build_stra,
            num_trainers=num_trainers,
            trainer_id=trainer_id)

        feed_var_list = [
            var for var in trainer_prog.global_block().vars.values()
            if var.is_data
        ]

        feeder = fluid.DataFeeder(feed_var_list, place)
        reader_generator = train_reader()

        def get_data():
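            # When running distributed with --use_reader_alloc, each trainer
            # takes every other sample (offset % 2 == trainer_id), so the two
            # trainers consume disjoint halves of each batch.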
            origin_batch = next(reader_generator)
            if args.update_method != "local" and args.use_reader_alloc:
                new_batch = []
                for offset, item in enumerate(origin_batch):
                    if offset % 2 == args.trainer_id:
                        new_batch.append(item)
                return new_batch
            else:
                return origin_batch

        out_losses = []
        for _ in six.moves.xrange(RUN_STEP):
            loss, = exe.run(fetch_list=[avg_cost.name],
                            feed=feeder.feed(get_data()))
            out_losses.append(loss[0])
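        # Serialize the losses to stdout so the parent test process can read
        # them back with pickle.loads; Python 3 has to write the raw bytes
        # through sys.stdout.buffer.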
        if six.PY2:
            print(pickle.dumps(out_losses))
        else:
            sys.stdout.buffer.write(pickle.dumps(out_losses))


def runtime_main(test_class):
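    """Entry point of a spawned trainer/pserver process: parse the role,
    endpoints and training options from the command line and run the given
    TestDistRunnerBase subclass accordingly."""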
    parser = argparse.ArgumentParser(description='Run dist test.')
    parser.add_argument(
        '--role', type=str, required=True, choices=['pserver', 'trainer'])
    parser.add_argument('--endpoints', type=str, required=False, default="")
    parser.add_argument(
        '--update_method',
        type=str,
        default="local",
        choices=["pserver", "nccl2", "local"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument(
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--mem_opt', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
    parser.add_argument(
        '--use_reader_alloc', action='store_true', required=False)
    parser.add_argument('--batch_size', required=False, type=int, default=2)
    parser.add_argument(
        '--batch_merge_repeat', required=False, type=int, default=1)

    args = parser.parse_args()

    model = test_class()
    if args.role == "pserver" and args.update_method == "pserver":
        model.run_pserver(args)
    else:
        model.run_trainer(args)


import paddle.compat as cpt
import socket
from contextlib import closing


class TestDistBase(unittest.TestCase):
    def _setup_config(self):
        raise NotImplementedError("tests should have _setup_config implemented")

    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
            else:
                self.__use_cuda = False

    def setUp(self):
        self._trainers = 2
        self._pservers = 2
        self._port_set = set()
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(), self._find_free_port())
        self._python_interp = sys.executable
        self._sync_mode = True
        self._enforce_place = None
        self._mem_opt = False
        self._use_reduce = False
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
        self._setup_config()
        self._after_setup_config()

    def _find_free_port(self):
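        """Bind to port 0 so the OS picks an unused port, and remember every
        port already handed out so endpoints within one test never collide."""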
        def __free_port():
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def start_pserver(self, model_file, check_error_log, required_envs):
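        """Launch the two parameter-server processes for model_file and return
        their process handles together with their stderr log files."""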
        ps0_ep, ps1_ep = self._ps_endpoints.split(",")
        ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --update_method pserver"
        ps0_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
                   self._trainers)
        ps1_cmd = ps_cmd % \
                  (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
                   self._trainers)

        if self._sync_mode:
            ps0_cmd += " --sync_mode"
            ps1_cmd += " --sync_mode"
        if self._mem_opt:
            ps0_cmd += " --mem_opt"
            ps1_cmd += " --mem_opt"

        print(ps0_cmd)
        print(ps1_cmd)
        ps0_pipe = open("/tmp/ps0_err.log", "wb")
        ps1_pipe = open("/tmp/ps1_err.log", "wb")

        ps0_proc = subprocess.Popen(
            ps0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps0_pipe,
            env=required_envs)
        ps1_proc = subprocess.Popen(
            ps1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=ps1_pipe,
            env=required_envs)

        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1):
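        """Run the model in a single local process and return its per-step
        losses (unpickled from the subprocess stdout)."""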

        cmd = "%s %s --role trainer" % (self._python_interp, model)
        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat

        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {"CUDA_VISIBLE_DEVICES": "0"}
        else:
            env_local = {'CPU_NUM': '1'}

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            err_log = open("/tmp/trainer.err.log", "wb")
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=err_log,
                env=env_local)
        else:
            local_proc = subprocess.Popen(
                cmd.split(" "),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster(self, model, envs, check_error_log):
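        """Start two pservers and two trainers in pserver update mode and
        return the per-step losses of both trainers."""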
        # Run dist train to compare with local results
        ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                          check_error_log, envs)

        ps0_ep, ps1_ep = self._ps_endpoints.split(",")

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, ps0_ep, self._trainers)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, ps1_ep, self._trainers)

        if self._sync_mode:
            tr0_cmd += " --sync_mode"
            tr1_cmd += " --sync_mode"
        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd: {}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd: {}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        # Wait until the trainer processes terminate
        while True:
            stat0 = tr0_proc.poll()
            time.sleep(0.1)
            if stat0 is not None:
                break
        while True:
            stat1 = tr1_proc.poll()
            time.sleep(0.1)
            if stat1 is not None:
                break

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close the trainer and pserver log file handles
        tr0_pipe.close()
        tr1_pipe.close()
        ps0_pipe.close()
        ps1_pipe.close()

        ps0.terminate()
        ps1.terminate()

        # print server log
        with open("/tmp/ps0_err.log", "r") as fn:
            sys.stderr.write("ps0 stderr: %s\n" % fn.read())
        with open("/tmp/ps1_err.log", "r") as fn:
            sys.stderr.write("ps1 stderr: %s\n" % fn.read())

        # print trainer logs
        if stat0 == 0:
            sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
        with open("/tmp/tr0_err.log", "r") as fn:
            sys.stderr.write('trainer 0 stderr: %s\n' % fn.read())
        if stat1 == 0:
            sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
        with open("/tmp/tr1_err.log", "r") as fn:
            sys.stderr.write('trainer 1 stderr: %s\n' % fn.read())

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def _run_cluster_nccl2(self, model, envs, check_error_log):
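        """Start two trainer processes in nccl2 (collective) mode and return
        the per-step losses of both trainers."""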
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints

        tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2"
        tr0_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   0, w0_ep)
        tr1_cmd = tr_cmd % \
                  (self._python_interp, model, self._ps_endpoints,
                   1, w1_ep)

        if self._mem_opt:
            tr0_cmd += " --mem_opt"
            tr1_cmd += " --mem_opt"
        if self._use_reduce:
            tr0_cmd += " --use_reduce"
            tr1_cmd += " --use_reduce"
        if self._use_reader_alloc:
            tr0_cmd += " --use_reader_alloc"
            tr1_cmd += " --use_reader_alloc"
        if self.__use_cuda:
            tr0_cmd += " --use_cuda"
            tr1_cmd += " --use_cuda"
            env0 = {"CUDA_VISIBLE_DEVICES": "0"}
            env1 = {"CUDA_VISIBLE_DEVICES": "1"}
        else:
            env0 = {'CPU_NUM': '1'}
            env1 = {'CPU_NUM': '1'}

        env0.update(envs)
        env1.update(envs)

        print("tr0_cmd:{}, env: {}".format(tr0_cmd, env0))
        print("tr1_cmd:{}, env: {}".format(tr1_cmd, env1))
        tr0_pipe = open("/tmp/tr0_err.log", "wb")
        tr1_pipe = open("/tmp/tr1_err.log", "wb")

        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0)
        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(" "),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1)

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()

        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()

        # print log
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        sys.stderr.write('trainer 0 stdout: %s\n' % tr0_out)
        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_out)

        return pickle.loads(tr0_out), pickle.loads(tr1_out)

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=False,
                         need_envs={}):
        # TODO(typhoonzero): should auto adapt GPU count on the machine.
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "FLAGS_cudnn_deterministic": "1",
            "http_proxy": "",
            "NCCL_P2P_DISABLE": "1"
        }

        required_envs.update(need_envs)

        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"

        local_losses\
            = self._run_local(model_file, required_envs,
                                       check_error_log)
        if self._nccl2_mode:
            tr0_losses, tr1_losses = self._run_cluster_nccl2(
                model_file, required_envs, check_error_log)
        else:
            tr0_losses, tr1_losses = self._run_cluster(
                model_file, required_envs, check_error_log)

        for step_id in range(RUN_STEP):
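            # The distributed loss is the mean of the two trainers' losses and
            # should match the local loss of the same step within `delta`.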
            local_loss = local_losses[step_id]
            tr0_loss = tr0_losses[step_id]
            tr1_loss = tr1_losses[step_id]
            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
            print("=======", local_loss, ":", dist_loss[0], "=======")
            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
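

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the file and class names below are
# hypothetical, not part of this module).
#
# A model script, e.g. dist_mnist.py, subclasses TestDistRunnerBase and is
# launched as a separate trainer/pserver process by this harness:
#
#     class TestDistMnist2x2(TestDistRunnerBase):
#         def get_model(self, batch_size=DEFAULT_BATCH_SIZE):
#             # build the network and return (test_program, avg_cost,
#             # train_reader, test_reader, batch_acc, predict)
#             ...
#
#     if __name__ == "__main__":
#         runtime_main(TestDistMnist2x2)
#
# The unittest side subclasses TestDistBase and points check_with_place at
# that script:
#
#     class TestDistMnist(TestDistBase):
#         def _setup_config(self):
#             self._sync_mode = True
#
#         def test_dist_train(self):
#             self.check_with_place("dist_mnist.py", delta=1e-5)
# ---------------------------------------------------------------------------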