# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
17
import socket
18 19
import subprocess
import sys
20
import tempfile
21
import unittest
22
from contextlib import closing
23 24 25 26

import numpy as np
from paddle_bfloat import bfloat16

27
import paddle
28 29 30 31
import paddle.fluid as fluid
from paddle.fluid import core


def create_bool_test_data(shape=None, seed=None):
    """Return a random boolean array of the given shape.

    Args:
        shape: shape of the returned array.
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    return np.random.choice([True, False], size=shape)


def create_float_test_data(shape=None, dtype=None, seed=None):
    """Return uniform random floats in [0, 1) cast to ``dtype``.

    Args:
        shape: shape of the returned array.
        dtype: target dtype (a numpy dtype name or dtype-like object).
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    return np.random.random(shape).astype(dtype)


def create_int_test_data(shape=None, dtype=None, seed=None):
    """Return random integers in [0, 100) cast to ``dtype``.

    Args:
        shape: shape of the returned array.
        dtype: target integer dtype name.
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    return np.random.randint(0, high=100, size=shape).astype(dtype)


def create_complex_test_data(shape=None, dtype=None, seed=None):
    """Return complex data whose real and imaginary parts are uniform randoms.

    Args:
        shape: shape of the returned array.
        dtype: complex dtype name, e.g. "complex64" or "complex128".
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    # astype() leaves the imaginary part at zero; fill it separately.
    data.imag = np.random.random(shape)
    return data


def create_pylist_test_data(shape=None, seed=None):
    """Return a nested Python list of random floats with a random 2-D shape.

    ``shape`` is accepted for interface compatibility but ignored: the
    xxx_object APIs are exercised with a shape whose two dimensions are drawn
    from [0, 100).

    Args:
        shape: ignored.
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    # Generate random shape test case for xxx_object api
    random_shape = np.random.randint(0, high=100, size=2).tolist()
    return np.random.random(random_shape).tolist()


def create_pydict_test_data(shape=None, seed=None):
    """Return a dict mapping ``0..shape[0]-1`` to random float rows.

    Args:
        shape: shape of the random array; its first dimension is the number
            of keys and each value is the corresponding row as a list.
        seed: if not None, numpy's global RNG is re-seeded with it first
            (``0`` is a valid seed and is honored).
    """
    if seed is not None:
        np.random.seed(seed)
    value = np.random.random(shape).tolist()
    return dict(zip(range(shape[0]), value))


def create_test_data(shape=None, dtype=None, seed=None):
    """Create random test data for ``dtype`` by dispatching to a generator.

    Args:
        shape: required; ignored by the "pylist" generator.
        dtype (str): one of "float16"/"float32"/"float64", "bfloat16", "bool",
            "int8"/"int32"/"int64"/"uint8", "complex64"/"complex128",
            "pylist" or "pydict".
        seed: forwarded to the per-dtype generator.

    Raises:
        AssertionError: if ``shape`` is empty or None.
        NotImplementedError: if ``dtype`` is not supported.
    """
    assert shape, "Shape should be specified"
    if dtype in ("float32", "float16", "float64"):
        return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
    if dtype == "bfloat16":
        # since numpy does not support bfloat16 yet, use `paddle_bfloat` to replace
        return create_float_test_data(shape=shape, dtype=bfloat16, seed=seed)
    if dtype == "bool":
        return create_bool_test_data(shape=shape, seed=seed)
    if dtype in ("int32", "int64", "int8", "uint8"):
        return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
    if dtype in ("complex64", "complex128"):
        return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
    if dtype == "pylist":
        return create_pylist_test_data(shape=shape, seed=seed)
    if dtype == "pydict":
        return create_pydict_test_data(shape=shape, seed=seed)
    raise NotImplementedError(
        "Unsupported dtype for creating test data: %r" % (dtype,)
    )


class TestCollectiveAPIRunnerBase:
    """Base class for the trainer scripts launched by ``TestDistBase``.

    A child class implements :meth:`get_model`. :meth:`run_trainer` builds
    this rank's input, runs the model in static or dynamic mode and writes
    the pickled result to stdout.
    """

    def get_model(
        self, train_prog, startup_prog, rank, indata=None, dtype=None
    ):
        """Build/run the collective under test; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "get model should be implemented by child class."
        )

    def run_trainer(self, args):
        """Run one trainer and write its pickled result to stdout.

        Args:
            args (dict): as built by ``runtime_main``; this method reads the
                "trainerid", "backend", "dtype" and "static_mode" entries.
        """
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        rank = args["trainerid"]
        paddle.distributed.init_parallel_env()

        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            place = fluid.CPUPlace()

        # Seeding with this process's pid lets the launching test regenerate
        # the same input from the subprocess pid to build expected results.
        indata = create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
        if args['static_mode']:
            # Static-graph models are built without `indata`; the input is fed
            # through the 'tindata' feed variable instead.
            result = self.get_model(train_prog, startup_prog, rank)
            exe = fluid.Executor(place)
            exe.run(startup_prog)
            fetch_list = [elem.name for elem in result]
            out = exe.run(
                train_prog, feed={'tindata': indata}, fetch_list=fetch_list
            )
        else:
            out = self.get_model(train_prog, startup_prog, rank, indata)
        sys.stdout.buffer.write(pickle.dumps(out))


def runtime_main(test_class, col_type):
    """Instantiate ``test_class`` and run one trainer from launcher env vars.

    The trainer id/count, endpoints, backend, path id, static-mode flag and
    dtype are read from the environment set up by ``TestDistBase``; numeric
    fields are converted with ``int``.
    """
    model = test_class()
    read_env = os.getenv
    trainer_args = {
        "trainerid": int(read_env("PADDLE_TRAINER_ID")),
        "trainernum": int(read_env("PADDLE_TRAINERS_NUM")),
        "endpoints": read_env('PADDLE_TRAINER_ENDPOINTS'),
        "currentendpoint": read_env("PADDLE_CURRENT_ENDPOINT"),
        "col_type": col_type,
        "backend": read_env("BACKEND"),
        "path_id": int(read_env("PATH_ID")),
        "static_mode": int(read_env("STATIC_MODE")),
        "dtype": read_env("DTYPE"),
    }
    model.run_trainer(trainer_args)


class TestDistBase(unittest.TestCase):
    """Run a collective trainer script on two ranks and verify the results.

    ``check_with_place`` launches ``model_file`` twice as subprocesses, reads
    each trainer's pickled stdout, regenerates the trainers' inputs from their
    pids (``run_trainer`` seeds with ``os.getpid()``) and compares against a
    numpy reference chosen by ``col_type``.
    """

    def setUp(self):
        self._port_set = set()
        self._trainers = 2
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
            self._find_free_port(),
            self._find_free_port(),
        )
        self._python_interp = sys.executable
        self.temp_dir = tempfile.TemporaryDirectory()

        # NOTE: this is a hack to get int format nccl version, like 2134
        # if current platform is not linux, version number will be 0
        nccl_version_str = (
            subprocess.check_output(
                r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
                stderr=subprocess.DEVNULL,
                shell=True,
            )
            .decode('utf-8')
            .strip()
        )
        self._nccl_version = (
            int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
        )

    def tearDown(self):
        self.temp_dir.cleanup()

    def _find_free_port(self):
        """Return an OS-assigned TCP port not yet handed out by this test."""

        def _bind_ephemeral():
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = _bind_ephemeral()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _run_cluster(self, model_file, envs):
        """Launch two trainers running ``model_file`` and collect results.

        Args:
            model_file (str): trainer script; it must pickle its result to
                stdout. Whitespace-separated extra tokens are passed as args.
            envs (dict): environment applied on top of (and overriding) the
                per-rank launcher variables.

        Returns:
            tuple: (rank 0 result, rank 1 result, rank 0 pid, rank 1 pid).
        """
        w0_ep, w1_ep = self._ps_endpoints.split(",")
        if core.is_compiled_with_cuda():
            device_flag = "FLAGS_selected_gpus"
        elif core.is_compiled_with_xpu():
            device_flag = "FLAGS_selected_xpus"
        else:
            # CPU-only build (e.g. gloo backend): no device to select.
            device_flag = None

        def _trainer_env(trainer_id, endpoint):
            env = {
                "PADDLE_TRAINER_ID": str(trainer_id),
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": endpoint,
            }
            if device_flag is not None:
                env[device_flag] = str(trainer_id)
            # update environment: caller-provided values win
            env.update(envs)
            return env

        env0 = _trainer_env(0, w0_ep)
        env1 = _trainer_env(1, w1_ep)

        # Build argv as a list so an interpreter path containing spaces works.
        tr_cmd = [self._python_interp]
        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            tr_cmd += ["-m", "coverage", "run", "--branch", "-p"]
        tr_cmd += model_file.strip().split()

        # Relative names keep the logs inside temp_dir (an absolute component
        # would make os.path.join discard temp_dir entirely).
        path0 = os.path.join(
            self.temp_dir.name, "tr0_err_%d.log" % os.getpid()
        )
        path1 = os.path.join(
            self.temp_dir.name, "tr1_err_%d.log" % os.getpid()
        )

        with open(path0, "w") as tr0_pipe, open(path1, "w") as tr1_pipe:
            tr0_proc = subprocess.Popen(
                tr_cmd,
                stdout=subprocess.PIPE,
                stderr=tr0_pipe,
                env=env0,
            )
            tr1_proc = subprocess.Popen(
                tr_cmd,
                stdout=subprocess.PIPE,
                stderr=tr1_pipe,
                env=env1,
            )
            tr0_out, tr0_err = tr0_proc.communicate()
            tr1_out, tr1_err = tr1_proc.communicate()

        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        with open(path0, "r") as f:
            sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
        with open(path1, "r") as f:
            sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
        return (
            pickle.loads(tr0_out),
            pickle.loads(tr1_out),
            tr0_proc.pid,
            tr1_proc.pid,
        )

    @staticmethod
    def _expected_global_gather(pid0, pid1):
        """Recompute the global_gather inputs and expected outputs.

        Returns:
            tuple: (expected rank 0 output, expected rank 1 output,
            rank 0 local input buffer, rank 1 local input buffer).
        """
        in_feat = 2
        n_expert = 2
        world_size = 2
        tot_expert = n_expert * world_size

        # NOTE(review): the seed/draw order presumably mirrors the trainer
        # script; keep it unchanged or the expected data will diverge.
        np.random.seed(pid0)
        local_expert_count1 = np.random.randint(
            1, 4, size=tot_expert
        ).astype("int")
        np.random.seed(pid1)
        local_expert_count2 = np.random.randint(
            1, 4, size=tot_expert
        ).astype("int")

        # Assemble each rank's global counts from slices of both local counts.
        global_expert_count1 = np.zeros(tot_expert).astype("int")
        global_expert_count2 = np.zeros(tot_expert).astype("int")
        global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert]
        global_expert_count1[n_expert:] = local_expert_count2[0:n_expert]
        global_expert_count2[0:n_expert] = local_expert_count1[n_expert:]
        global_expert_count2[n_expert:] = local_expert_count2[n_expert:]

        np.random.seed(pid0)
        fwd_expert_count = sum(global_expert_count1).astype("int")
        local_input_buf1 = np.random.rand(fwd_expert_count, in_feat).astype(
            "float32"
        )
        np.random.seed(pid1)
        fwd_expert_count = sum(global_expert_count2).astype("int")
        local_input_buf2 = np.random.rand(fwd_expert_count, in_feat).astype(
            "float32"
        )

        rows1 = [[] for _ in range(tot_expert)]
        rows2 = [[] for _ in range(tot_expert)]
        send_ptr1 = 0
        send_ptr2 = 0
        for i in range(n_expert):
            for j in range(world_size):
                idx = j * n_expert + i
                count1 = global_expert_count1[idx]
                count2 = global_expert_count2[idx]
                dst = rows1 if j == 0 else rows2
                dst[i].extend(local_input_buf1[send_ptr1 : send_ptr1 + count1])
                dst[i + n_expert].extend(
                    local_input_buf2[send_ptr2 : send_ptr2 + count2]
                )
                send_ptr1 += count1
                send_ptr2 += count2

        def _stack(rows, n_rows):
            # Flatten per-expert row lists (in expert order) into one matrix.
            flat = [row for expert_rows in rows for row in expert_rows]
            if not flat:
                return np.array([])
            return np.concatenate(flat, axis=0).reshape(n_rows, in_feat)

        output1 = _stack(rows1, sum(local_expert_count1))
        output2 = _stack(rows2, sum(local_expert_count2))
        return output1, output2, local_input_buf1, local_input_buf2

    @staticmethod
    def _expected_global_scatter(pid0, pid1):
        """Recompute the global_scatter inputs and expected outputs.

        Returns:
            tuple: (expected rank 0 output, expected rank 1 output,
            rank 0 local input buffer, rank 1 local input buffer).
        """
        # NOTE(review): the seed/draw order presumably mirrors the trainer
        # script; keep it unchanged or the expected data will diverge.
        np.random.seed(pid0)
        local_expert_count1 = np.random.randint(1, 4, size=4).astype("int")
        local_input_buf1 = np.random.rand(sum(local_expert_count1), 2).astype(
            "float32"
        )
        np.random.seed(pid1)
        local_expert_count2 = np.random.randint(1, 4, size=4).astype("int")
        local_input_buf2 = np.random.rand(sum(local_expert_count2), 2).astype(
            "float32"
        )

        # Exclusive prefix sums: first row of each expert's slice.
        expert_ptr1 = np.concatenate(([0], np.cumsum(local_expert_count1)[:-1]))
        expert_ptr2 = np.concatenate(([0], np.cumsum(local_expert_count2)[:-1]))

        output1 = []
        output2 = []
        for i in range(2):
            for j in range(2):
                idx = j * 2 + i
                # j == 0: data sent to card 0; otherwise to card 1
                dst = output1 if j == 0 else output2
                dst.append(
                    local_input_buf1[
                        expert_ptr1[idx] : expert_ptr1[idx]
                        + local_expert_count1[idx]
                    ]
                )
                dst.append(
                    local_input_buf2[
                        expert_ptr2[idx] : expert_ptr2[idx]
                        + local_expert_count2[idx]
                    ]
                )
        output1 = np.concatenate(output1) if output1 else np.array([])
        output2 = np.concatenate(output2) if output2 else np.array([])
        return output1, output2, local_input_buf1, local_input_buf2

    @staticmethod
    def _check_moe_outputs(
        tr0_out, tr1_out, output1, output2, buf1, buf2, static_mode
    ):
        """Compare global_gather/global_scatter trainer outputs.

        Missing or empty first outputs are normalized to ``np.array([])`` in
        place. When ``static_mode`` is 0 (int or the string "0"), the second
        output of each trainer is also compared with ``2 * local input``
        (presumably the gradient of a squared loss; confirm in the trainer).
        """
        for tr_out in (tr0_out, tr1_out):
            if tr_out[0] is None or tr_out[0].shape[0] == 0:
                tr_out[0] = np.array([])

        np.testing.assert_allclose(tr0_out[0], output1, rtol=1e-05, atol=1e-05)
        np.testing.assert_allclose(tr1_out[0], output2, rtol=1e-05, atol=1e-05)
        # static_mode arrives as a string ("0"/"1") from callers; normalize.
        if int(static_mode) == 0:
            np.testing.assert_allclose(
                tr0_out[1], 2 * buf1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[1], 2 * buf2, rtol=1e-05, atol=1e-05
            )

    def check_with_place(
        self,
        model_file,
        col_type,
        backend="nccl",
        path_id="0",
        static_mode="1",
        check_error_log=False,
        need_envs=None,
        eager_mode=True,
        dtype=None,
    ):
        """Run ``model_file`` on two trainers and verify the ``col_type`` result.

        Args:
            model_file (str): trainer script passed to ``_run_cluster``.
            col_type (str): collective to verify; unknown values check nothing.
            backend (str): "nccl"/"bkcl" disable gloo, anything else enables it.
            path_id (str|int): forwarded to the trainers as PATH_ID.
            static_mode (str|int): "1" static graph, "0" dynamic graph.
            check_error_log (bool): enable verbose GLOG/GLOO logging.
            need_envs (dict|None): extra environment variables for trainers.
            eager_mode (bool): unused; kept for API compatibility.
            dtype (str|None): test data dtype; "float32" when None.
        """
        with_gloo = '0' if backend in ("nccl", "bkcl") else '1'
        dtype = "float32" if dtype is None else dtype

        required_envs = os.environ.copy()
        required_envs.update(
            {
                "NCCL_P2P_DISABLE": "1",
                "STATIC_MODE": str(static_mode),
                "PADDLE_WITH_GLOO": with_gloo,
                "PADDLE_DISTRI_BACKEND": backend,
                "BACKEND": backend,
                "PATH_ID": str(path_id),
                "DTYPE": dtype,
            }
        )
        if need_envs:
            required_envs.update(need_envs)
        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"
            required_envs["GLOO_LOG_LEVEL"] = "TRACE"
        # Forward the TF32 override only when it is actually set.
        tf32_override = os.getenv('NVIDIA_TF32_OVERRIDE')
        if tf32_override is not None:
            required_envs['NVIDIA_TF32_OVERRIDE'] = tf32_override

        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
            model_file, required_envs
        )

        # Trainers seed their inputs with their own pid, so regenerate them.
        input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
        input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
        # cast bfloat16 to float32 for numeric comparison
        if dtype == "bfloat16":
            input1 = input1.astype("float32")
            input2 = input2.astype("float32")

        # bfloat16 precision loss comes from truncating the last 16 bits of
        # float32, which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078
        sum_tol = 8e-03 if dtype == "bfloat16" else 1e-05

        if col_type == "allgather":
            need_result = np.vstack((input1, input2))
            tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
            tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
            np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05)
            np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05)
        elif col_type == "allgather_object":
            need_result = [input1, input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
        elif col_type == "broadcast":
            need_result = input2
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
        elif col_type == "reduce":
            need_result = input1 + input2
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=sum_tol)
        elif col_type == "scatter":
            need_result = input2
            mid = need_result.shape[0] // 2
            np.testing.assert_allclose(
                tr0_out[0], need_result[:mid], rtol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[0], need_result[mid:], rtol=1e-05
            )
        elif col_type == "reduce_scatter":
            need_result = input1 + input2
            mid = need_result.shape[0] // 2
            np.testing.assert_allclose(
                tr0_out[0], need_result[:mid], rtol=sum_tol
            )
            np.testing.assert_allclose(
                tr1_out[0], need_result[mid:], rtol=sum_tol
            )
        elif col_type == "allreduce":
            need_result = input1 + input2
            np.testing.assert_allclose(
                tr0_out[0], need_result, rtol=sum_tol, atol=sum_tol
            )
            np.testing.assert_allclose(
                tr1_out[0], need_result, rtol=sum_tol, atol=sum_tol
            )
        elif col_type == "parallel_embedding":
            result_data = tr0_out[0]
            np.random.seed(2020)
            need_result = np.random.rand(12, 8)
            for i in range(result_data.shape[0]):
                for j in range(result_data.shape[1]):
                    data = result_data[i][j]
                    np.testing.assert_allclose(
                        tr0_out[1][i][j], need_result[data], atol=1e-08
                    )
        elif col_type in ("row_parallel_linear", "column_parallel_linear"):
            result_data = tr0_out[0]
            np.random.seed(2020)
            weight = np.random.rand(1000, 16)
            need_result = np.matmul(input1, weight)
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
        elif col_type == "alltoall":
            mid1 = input1.shape[0] // 2
            mid2 = input2.shape[0] // 2
            need_result1 = np.vstack((input1[:mid1, :], input2[:mid2, :]))
            need_result2 = np.vstack((input1[mid1:, :], input2[mid2:, :]))
            np.testing.assert_allclose(
                np.vstack(tr0_out), need_result1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                np.vstack(tr1_out), need_result2, rtol=1e-05, atol=1e-05
            )
        elif col_type == "sendrecv":
            result_data = tr1_out[0]
            np.testing.assert_allclose(
                input1, result_data, rtol=1e-05, atol=1e-05
            )
        elif col_type == "global_gather":
            output1, output2, buf1, buf2 = self._expected_global_gather(
                pid0, pid1
            )
            self._check_moe_outputs(
                tr0_out, tr1_out, output1, output2, buf1, buf2, static_mode
            )
        elif col_type == "global_scatter":
            output1, output2, buf1, buf2 = self._expected_global_scatter(
                pid0, pid1
            )
            self._check_moe_outputs(
                tr0_out, tr1_out, output1, output2, buf1, buf2, static_mode
            )
        else:
            pass