# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import socket
import subprocess
import sys
import tempfile
import unittest
from contextlib import closing

import numpy as np
from paddle_bfloat import bfloat16

import paddle
import paddle.fluid as fluid
from paddle.fluid import core


def create_bool_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.choice([True, False], size=shape)
    return data


def create_float_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    return data


def create_int_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.randint(0, high=100, size=shape).astype(dtype)
    return data


def create_complex_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    data.imag = np.random.random(shape)
    return data


def create_pylist_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    # Generate a random shape for the xxx_object apis; the passed-in
    # shape argument is intentionally ignored.
    shape = np.random.randint(0, high=100, size=(2)).tolist()
    data = np.random.random(shape).tolist()
    return data


def create_pydict_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    key = list(range(shape[0]))
    value = np.random.random(shape).tolist()
    data = dict(zip(key, value))
    return data


def create_test_data(shape=None, dtype=None, seed=None):
    assert shape, "Shape should be specified"
    if dtype in ("float32", "float16", "float64"):
        return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "bfloat16":
        # numpy does not support bfloat16 yet, so use `paddle_bfloat` instead
        return create_float_test_data(shape=shape, dtype=bfloat16, seed=seed)
    elif dtype == "bool":
        return create_bool_test_data(shape=shape, seed=seed)
    elif dtype in ("int32", "int64", "int8", "uint8"):
        return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype in ("complex64", "complex128"):
        return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "pylist":
        return create_pylist_test_data(shape=shape, seed=seed)
    elif dtype == "pydict":
        return create_pydict_test_data(shape=shape, seed=seed)
    else:
        raise NotImplementedError("Unsupported dtype for creating test data.")
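
# A minimal usage sketch (illustrative values): the trainer below and the
# result checker both derive the same input from a pid-based seed, e.g.
#
#     indata = create_test_data(shape=(10, 1000), dtype="float32", seed=1234)
#
# so the expected collective outputs can be recomputed in the parent process.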


class TestCollectiveAPIRunnerBase:
    def get_model(
        self, train_prog, startup_prog, rank, indata=None, dtype=None
    ):
        raise NotImplementedError(
            "get_model should be implemented by the child class."
        )

    def run_trainer(self, args):
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        paddle.distributed.init_parallel_env()
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()
        # Seed with this process's pid so the parent test can re-create the
        # exact same input when checking results.
        indata = create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        if args['static_mode']:
            result = self.get_model(train_prog, startup_prog, rank)
            exe = fluid.Executor(place)
            exe.run(startup_prog)
            fetch_list = []
            for elem in result:
                fetch_list.append(elem.name)
            out = exe.run(
                train_prog, feed={'tindata': indata}, fetch_list=fetch_list
            )
        else:
            out = self.get_model(train_prog, startup_prog, rank, indata)
        # Ship results back to the parent as a pickled blob on stdout;
        # stderr goes to the per-trainer log file set up by the launcher.
        sys.stdout.buffer.write(pickle.dumps(out))


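# runtime_main reads the launcher-provided environment (PADDLE_TRAINER_ID,
# BACKEND, DTYPE, STATIC_MODE, ...) and drives a single trainer process.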
def runtime_main(test_class, col_type):
    args = {}
    model = test_class()
    args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
    args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
    args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
    args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
    args["col_type"] = col_type
    args["backend"] = os.getenv("BACKEND")
    args["path_id"] = int(os.getenv("PATH_ID"))
    args["static_mode"] = int(os.getenv("STATIC_MODE"))
    args["dtype"] = os.getenv("DTYPE")
    model.run_trainer(args)
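
# A model file built on this base typically defines a subclass and calls
# runtime_main; a minimal sketch (illustrative names, not part of this file):
#
#     class TestCollectiveAllReduceAPI(TestCollectiveAPIRunnerBase):
#         def get_model(self, main_prog, startup_program, rank, indata=None):
#             with fluid.program_guard(main_prog, startup_program):
#                 tindata = paddle.static.data(
#                     name="tindata", shape=[10, 1000], dtype='float32'
#                 )
#                 paddle.distributed.all_reduce(tindata)
#                 return [tindata]
#
#     if __name__ == "__main__":
#         runtime_main(TestCollectiveAllReduceAPI, "allreduce")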


class TestDistBase(unittest.TestCase):
    def setUp(self):
        self._port_set = set()
        self._trainers = 2
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(),
            self._find_free_port(),
        )
        self._python_interp = sys.executable

        self.temp_dir = tempfile.TemporaryDirectory()

        # NOTE: this is a hack to get the nccl version as an int, e.g.
        # libnccl.so.2.13.4 -> "2.13.4" -> 2134; if the current platform is
        # not linux (or nccl is absent), the version number will be 0
        nccl_version_str = subprocess.check_output(
            r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
            stderr=subprocess.DEVNULL,
            shell=True,
        ).decode('utf-8')
        self._nccl_version = (
            int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
        )

    def tearDown(self):
        self.temp_dir.cleanup()

    def _find_free_port(self):
        # NOTE: the OS may hand a returned port to another process before a
        # trainer binds it; _port_set only prevents duplicates within this
        # test process.
        def __free_port():
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _run_cluster(self, model_file, envs):
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints
        if core.is_compiled_with_cuda():
            env0 = {
                "FLAGS_selected_gpus": "0",
                "PADDLE_TRAINER_ID": "0",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w0_ep,
            }

            env1 = {
                "FLAGS_selected_gpus": "1",
                "PADDLE_TRAINER_ID": "1",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w1_ep,
            }
        elif core.is_compiled_with_xpu():
            env0 = {
                "FLAGS_selected_xpus": "0",
                "PADDLE_TRAINER_ID": "0",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w0_ep,
            }

            env1 = {
                "FLAGS_selected_xpus": "1",
                "PADDLE_TRAINER_ID": "1",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w1_ep,
            }
        # update environment
        env0.update(envs)
        env1.update(envs)
        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            tr_cmd = "%s -m coverage run --branch -p %s"
        else:
            tr_cmd = "%s %s"
        tr0_cmd = tr_cmd % (self._python_interp, model_file)
        tr1_cmd = tr_cmd % (self._python_interp, model_file)
        path0 = os.path.join(
            self.temp_dir.name, "tr0_err_%d.log" % os.getpid()
        )
        path1 = os.path.join(
            self.temp_dir.name, "tr1_err_%d.log" % os.getpid()
        )
        tr0_pipe = open(path0, "w")
        tr1_pipe = open(path1, "w")
        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0,
        )

        tr1_proc = subprocess.Popen(
            tr1_cmd.strip().split(),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1,
        )

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()
        with open(path0, "r") as f:
            sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
        with open(path1, "r") as f:
            sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
        return (
            pickle.loads(tr0_out),
            pickle.loads(tr1_out),
            tr0_proc.pid,
            tr1_proc.pid,
        )

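    # A test case sketch (illustrative names): a concrete unittest subclass
    # pairs a model file with the collective type it exercises, e.g.
    #
    #     class TestCollectiveAllreduceAPI(TestDistBase):
    #         def test_allreduce_nccl(self):
    #             self.check_with_place(
    #                 "collective_allreduce_api.py", "allreduce", "nccl"
    #             )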
    def check_with_place(
        self,
        model_file,
        col_type,
        backend="nccl",
        path_id="0",
        static_mode="1",
        check_error_log=False,
        need_envs={},
        eager_mode=True,
        dtype=None,
    ):
        if backend == "nccl" or backend == "bkcl":
            with_gloo = '0'
        else:
            with_gloo = '1'
        required_envs = os.environ.copy()
        dtype = "float32" if dtype is None else dtype
        additional_envs = {
            "NCCL_P2P_DISABLE": "1",
            "STATIC_MODE": static_mode,
            "PADDLE_WITH_GLOO": with_gloo,
            "PADDLE_DISTRI_BACKEND": backend,
            "BACKEND": backend,
            "PATH_ID": path_id,
            "DTYPE": dtype,
        }
        required_envs.update(additional_envs)
        required_envs.update(need_envs)
        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"
            required_envs["GLOO_LOG_LEVEL"] = "TRACE"

        if os.getenv('NVIDIA_TF32_OVERRIDE') is not None:
            required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv(
                'NVIDIA_TF32_OVERRIDE'
            )

        if eager_mode:
            required_envs["FLAGS_enable_eager_mode"] = "1"
        else:
            required_envs["FLAGS_enable_eager_mode"] = "0"

        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
            model_file, required_envs
        )
        input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
        input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
        # cast bfloat16 to float32 for numeric comparison
        if dtype == "bfloat16":
            input1 = input1.astype("float32")
            input2 = input2.astype("float32")
        if col_type == "allgather":
            need_result = np.vstack((input1, input2))
            tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
            tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
            np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05)
            np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05)
        if col_type == "allgather_object":
            need_result = [input1, input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
        elif col_type == "broadcast":
            need_result = input2
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
        elif col_type == "reduce":
            need_result = input1 + input2
            # bfloat16 precision loss comes from truncating the low 16 bits
            # of a float32 mantissa, so the worst-case relative error is
            # about sum_{i=-23}^{-8} 2^i = 2^-7 - 2^-23 ~= 0.0078
            if dtype == "bfloat16":
                rtol = 8e-03
            else:
                rtol = 1e-05
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=rtol)
        elif col_type == "scatter":
            need_result = input2
            need_result1 = need_result[0 : need_result.shape[0] // 2]
            need_result2 = need_result[need_result.shape[0] // 2 :]
            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
        elif col_type == "reduce_scatter":
            need_result = input1 + input2
            need_result1 = need_result[0 : need_result.shape[0] // 2]
            need_result2 = need_result[need_result.shape[0] // 2 :]
            if dtype == "bfloat16":
                rtol = 8e-03
            else:
                rtol = 1e-05
            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=rtol)
            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=rtol)
        elif col_type == "allreduce":
            need_result = input1 + input2
            if dtype == "bfloat16":
                rtol = 8e-03
                atol = 8e-03
            else:
                rtol = 1e-05
                atol = 1e-05
            np.testing.assert_allclose(
                tr0_out[0], need_result, rtol=rtol, atol=atol
            )
            np.testing.assert_allclose(
                tr1_out[0], need_result, rtol=rtol, atol=atol
            )
        elif col_type == "parallel_embedding":
            result_data = tr0_out[0]
            np.random.seed(2020)
            need_result = np.random.rand(12, 8)
            for i in range(result_data.shape[0]):
                for j in range(result_data.shape[1]):
                    data = result_data[i][j]
                    assert np.allclose(
                        tr0_out[1][i][j], need_result[data], atol=1e-08
                    )
        elif col_type == "row_parallel_linear":
            result_data = tr0_out[0]
            np.random.seed(2020)
            weight = np.random.rand(1000, 16)
            need_result = np.matmul(input1, weight)
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
        elif col_type == "column_parallel_linear":
            result_data = tr0_out[0]
            np.random.seed(2020)
            weight = np.random.rand(1000, 16)
            need_result = np.matmul(input1, weight)
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
        elif col_type == "alltoall":
            need_result1 = np.vstack(
                (
                    input1[0 : input1.shape[0] // 2, :],
                    input2[0 : input2.shape[0] // 2, :],
                )
            )
            need_result2 = np.vstack(
                (
                    input1[input1.shape[0] // 2 :, :],
                    input2[input2.shape[0] // 2 :, :],
                )
            )
            tr0_out = np.vstack(tr0_out)
            tr1_out = np.vstack(tr1_out)
            np.testing.assert_allclose(
                tr0_out, need_result1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out, need_result2, rtol=1e-05, atol=1e-05
            )
        elif col_type == "sendrecv":
            result_data = tr1_out[0]
            np.testing.assert_allclose(
                input1, result_data, rtol=1e-05, atol=1e-05
            )
        elif col_type == "global_gather":
            in_feat = 2
            n_expert = 2
            world_size = 2
            tot_expert = n_expert * world_size

            np.random.seed(pid0)
            local_expert_count1 = np.random.randint(
                1, 4, size=tot_expert
            ).astype("int")
            expert_ptr1 = np.ones(tot_expert, dtype=np.int32)
            expert_ptr1[0] = 0
            for i in range(1, tot_expert):
                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]

            np.random.seed(pid1)
            local_expert_count2 = np.random.randint(
                1, 4, size=tot_expert
            ).astype("int")
            expert_ptr2 = np.ones(tot_expert, dtype=np.int32)
            expert_ptr2[0] = 0
            for i in range(1, tot_expert):
                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]

            global_expert_count1 = np.zeros(tot_expert).astype("int")
            global_expert_count2 = np.zeros(tot_expert).astype("int")
            global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert]
            global_expert_count1[n_expert:] = local_expert_count2[0:n_expert]
            global_expert_count2[0:n_expert] = local_expert_count1[n_expert:]
            global_expert_count2[n_expert:] = local_expert_count2[n_expert:]

            np.random.seed(pid0)
            fwd_expert_count = sum(global_expert_count1).astype("int")
            local_input_buf1 = np.random.rand(fwd_expert_count, in_feat).astype(
                "float32"
            )
            np.random.seed(pid1)
            fwd_expert_count = sum(global_expert_count2).astype("int")
            local_input_buf2 = np.random.rand(fwd_expert_count, in_feat).astype(
                "float32"
            )
            output1 = [[], [], [], []]
            output2 = [[], [], [], []]
            send_ptr1 = 0
            send_ptr2 = 0

            for i in range(n_expert):
                for j in range(world_size):
                    idx = j * n_expert + i
                    if j == 0:
                        output1_part1 = local_input_buf1[
                            send_ptr1 : send_ptr1 + global_expert_count1[idx], :
                        ]
                        output1_part2 = local_input_buf2[
                            send_ptr2 : send_ptr2 + global_expert_count2[idx], :
                        ]
                        output1[i].extend(output1_part1)
                        output1[i + n_expert].extend(output1_part2)
                    else:
                        output2_part1 = local_input_buf1[
                            send_ptr1 : send_ptr1 + global_expert_count1[idx]
                        ]
                        output2_part2 = local_input_buf2[
                            send_ptr2 : send_ptr2 + global_expert_count2[idx]
                        ]
                        output2[i].extend(output2_part1)
                        output2[i + n_expert].extend(output2_part2)
                    send_ptr1 = send_ptr1 + global_expert_count1[idx]
                    send_ptr2 = send_ptr2 + global_expert_count2[idx]
            result1 = []
            result2 = []
            for i in range(tot_expert):
                for arr in output1[i]:
                    if len(arr) == 0:
                        continue
                    result1.append(arr)
            for i in range(tot_expert):
                for arr in output2[i]:
                    if len(arr) == 0:
                        continue
                    result2.append(arr)
            if result1 == []:
                output1 = np.array([])
            else:
                output1 = np.concatenate(result1, axis=0).reshape(
                    sum(local_expert_count1), in_feat
                )
            if result2 == []:
                output2 = np.array([])
            else:
                output2 = np.concatenate(result2, axis=0).reshape(
                    sum(local_expert_count2), in_feat
                )

            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                tr0_out[0] = np.array([])

            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
                tr1_out[0] = np.array([])

            np.testing.assert_allclose(
                tr0_out[0], output1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[0], output2, rtol=1e-05, atol=1e-05
            )
            if int(static_mode) == 0:
                np.testing.assert_allclose(
                    tr0_out[1], 2 * local_input_buf1, rtol=1e-05, atol=1e-05
                )
                np.testing.assert_allclose(
                    tr1_out[1], 2 * local_input_buf2, rtol=1e-05, atol=1e-05
                )

        elif col_type == "global_scatter":
            np.random.seed(pid0)
            local_expert_count1 = np.random.randint(1, 4, size=4).astype("int")
            fwd_expert_count = sum(local_expert_count1)
            local_input_buf1 = np.random.rand(fwd_expert_count, 2).astype(
                "float32"
            )
            expert_ptr1 = np.ones(4, dtype=np.int32)
            expert_ptr1[0] = 0
            for i in range(1, 4):
                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]
            np.random.seed(pid1)
            local_expert_count2 = np.random.randint(1, 4, size=4).astype("int")
            fwd_expert_count = sum(local_expert_count2)
            local_input_buf2 = np.random.rand(fwd_expert_count, 2).astype(
                "float32"
            )
            expert_ptr2 = np.ones(4, dtype=np.int32)
            expert_ptr2[0] = 0
            for i in range(1, 4):
                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]

            output1 = []
            output2 = []
            for i in range(2):
                for j in range(2):
                    idx = j * 2 + i
                    if j == 0:
                        # send data to 0 card
                        output1.append(
                            local_input_buf1[
                                expert_ptr1[idx] : expert_ptr1[idx]
                                + local_expert_count1[idx]
                            ]
                        )
                        output1.append(
                            local_input_buf2[
                                expert_ptr2[idx] : expert_ptr2[idx]
                                + local_expert_count2[idx]
                            ]
                        )
                    else:
                        output2.append(
                            local_input_buf1[
                                expert_ptr1[idx] : expert_ptr1[idx]
                                + local_expert_count1[idx]
                            ]
                        )
                        output2.append(
                            local_input_buf2[
                                expert_ptr2[idx] : expert_ptr2[idx]
                                + local_expert_count2[idx]
                            ]
                        )
            if output1 == []:
                output1 = np.array([])
            else:
                output1 = np.concatenate(output1)
            if output2 == []:
                output2 = np.array([])
            else:
                output2 = np.concatenate(output2)

            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                tr0_out[0] = np.array([])

            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
                tr1_out[0] = np.array([])

            np.testing.assert_allclose(
                tr0_out[0], output1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[0], output2, rtol=1e-05, atol=1e-05
            )
            if int(static_mode) == 0:
                np.testing.assert_allclose(
                    tr0_out[1], 2 * local_input_buf1, rtol=1e-05, atol=1e-05
                )
                np.testing.assert_allclose(
                    tr1_out[1], 2 * local_input_buf2, rtol=1e-05, atol=1e-05
                )
        else:
            pass