test_collective_api_base.py 24.0 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
17
import socket
18 19
import subprocess
import sys
20
import tempfile
21
import unittest
22
from contextlib import closing
23 24 25 26

import numpy as np
from paddle_bfloat import bfloat16

27
import paddle
28 29 30 31
import paddle.fluid as fluid
from paddle.fluid import core


32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
def create_bool_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.choice([True, False], size=shape)
    return data


def create_float_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    return data


def create_int_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.randint(0, high=100, size=shape).astype(dtype)
    return data


def create_complex_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    data.imag = np.random.random(shape)
    return data


61
def create_pyobject_test_data(shape=None, seed=None):
62 63
    if seed:
        np.random.seed(seed)
64 65 66 67 68 69
    list_shape = np.random.randint(0, high=100, size=(2)).tolist()
    list_data = np.random.random(shape).tolist()
    dict_key = [i for i in range(0, shape[0])]
    dict_val = np.random.random(shape).tolist()
    dict_data = dict(zip(dict_key, dict_val))
    return [list_data, dict_data]
70 71 72 73 74 75


def create_test_data(shape=None, dtype=None, seed=None):
    assert shape, "Shape should be specified"
    if dtype == "float32" or dtype == "float16" or dtype == "float64":
        return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
76 77 78
    elif dtype == "bfloat16":
        # since numpy does not support bfloat16 yet, use `paddle_bfloat` to replace
        return create_float_test_data(shape=shape, dtype=bfloat16, seed=seed)
79 80
    elif dtype == "bool":
        return create_bool_test_data(shape=shape, seed=seed)
81 82 83 84 85 86
    elif (
        dtype == "int32"
        or dtype == "int64"
        or dtype == "int8"
        or dtype == "uint8"
    ):
87 88 89
        return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "complex64" or dtype == "complex128":
        return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
90 91
    elif dtype == "pyobject":
        return create_pyobject_test_data(shape=shape, seed=seed)
92 93 94 95
    else:
        raise NotImplementedError("Unsupported dtype for creating test data.")


96
class TestCollectiveAPIRunnerBase:
97 98 99
    def get_model(
        self, train_prog, startup_prog, rank, indata=None, dtype=None
    ):
100
        raise NotImplementedError(
101 102
            "get model should be implemented by child class."
        )
103 104 105 106 107 108 109 110

    def run_trainer(self, args):
        train_prog = fluid.Program()
        startup_prog = fluid.Program()
        endpoints = args["endpoints"].split(",")
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
111
        paddle.distributed.init_parallel_env()
112 113 114
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(
115 116
                device_id
            )  # if args.use_gpu else fluid.CPUPlace()
117 118 119
        elif args['backend'] == 'bkcl':
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
120 121
        else:
            place = fluid.CPUPlace()
122 123 124
        indata = create_test_data(
            shape=(10, 1000), dtype=args["dtype"], seed=os.getpid()
        )
L
lilong12 已提交
125 126 127 128 129 130 131
        if args['static_mode']:
            result = self.get_model(train_prog, startup_prog, rank)
            exe = fluid.Executor(place)
            exe.run(startup_prog)
            fetch_list = []
            for elem in result:
                fetch_list.append(elem.name)
132 133 134
            out = exe.run(
                train_prog, feed={'tindata': indata}, fetch_list=fetch_list
            )
L
lilong12 已提交
135 136
        else:
            out = self.get_model(train_prog, startup_prog, rank, indata)
137
            # print(out, sys.stderr)
T
tianshuo78520a 已提交
138
        sys.stdout.buffer.write(pickle.dumps(out))
139 140 141 142 143 144 145 146 147 148 149 150


def runtime_main(test_class, col_type):
    args = {}
    model = test_class()
    args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
    args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
    args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
    args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
    args["col_type"] = col_type
    args["backend"] = os.getenv("BACKEND")
    args["path_id"] = int(os.getenv("PATH_ID"))
L
lilong12 已提交
151
    args["static_mode"] = int(os.getenv("STATIC_MODE"))
152
    args["dtype"] = os.getenv("DTYPE")
153 154 155 156 157 158 159 160
    model.run_trainer(args)


class TestDistBase(unittest.TestCase):
    def setUp(self):
        self._port_set = set()
        self._trainers = 2
        self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
161 162 163
            self._find_free_port(),
            self._find_free_port(),
        )
164 165
        self._python_interp = sys.executable

166 167
        self.temp_dir = tempfile.TemporaryDirectory()

168 169 170 171 172
        # NOTE: this is a hack to get int format nccl version, like 2134
        # if current platform is not linux, version number will be 0
        nccl_version_str = subprocess.check_output(
            r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
            stderr=subprocess.DEVNULL,
173 174 175 176 177
            shell=True,
        ).decode('utf-8')
        self._nccl_version = (
            int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
        )
178

179 180 181
    def tearDown(self):
        self.temp_dir.cleanup()

182 183
    def _find_free_port(self):
        def __free_port():
184 185 186
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as s:
187 188 189 190 191 192 193 194 195 196 197 198
                s.bind(('', 0))
                return s.getsockname()[1]

        while True:
            port = __free_port()
            if port not in self._port_set:
                self._port_set.add(port)
                return port

    def _run_cluster(self, model_file, envs):
        worker_endpoints = self._ps_endpoints.split(",")
        w0_ep, w1_ep = worker_endpoints
199
        # print("w0_ep:",w0_ep," w1_ep:",w1_ep)
200 201 202 203 204 205
        if core.is_compiled_with_cuda():
            env0 = {
                "FLAGS_selected_gpus": "0",
                "PADDLE_TRAINER_ID": "0",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
206
                "PADDLE_CURRENT_ENDPOINT": w0_ep,
207 208 209 210 211 212 213
            }

            env1 = {
                "FLAGS_selected_gpus": "1",
                "PADDLE_TRAINER_ID": "1",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
214
                "PADDLE_CURRENT_ENDPOINT": w1_ep,
215 216 217 218 219 220 221
            }
        elif core.is_compiled_with_xpu():
            env0 = {
                "FLAGS_selected_xpus": "0",
                "PADDLE_TRAINER_ID": "0",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
222
                "PADDLE_CURRENT_ENDPOINT": w0_ep,
223 224 225 226 227 228 229
            }

            env1 = {
                "FLAGS_selected_xpus": "1",
                "PADDLE_TRAINER_ID": "1",
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
230
                "PADDLE_CURRENT_ENDPOINT": w1_ep,
231
            }
232
        # update environment
233 234
        env0.update(envs)
        env1.update(envs)
235 236 237 238
        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            tr_cmd = "%s -m coverage run --branch -p %s"
        else:
            tr_cmd = "%s %s"
239 240
        tr0_cmd = tr_cmd % (self._python_interp, model_file)
        tr1_cmd = tr_cmd % (self._python_interp, model_file)
241 242 243 244 245 246
        path0 = os.path.join(
            self.temp_dir.name, "/tmp/tr0_err_%d.log" % os.getpid()
        )
        path1 = os.path.join(
            self.temp_dir.name, "/tmp/tr1_err_%d.log" % os.getpid()
        )
247 248
        tr0_pipe = open(path0, "w")
        tr1_pipe = open(path1, "w")
249 250 251 252 253 254 255 256 257 258 259 260 261 262
        # print(tr0_cmd)
        tr0_proc = subprocess.Popen(
            tr0_cmd.strip().split(),
            stdout=subprocess.PIPE,
            stderr=tr0_pipe,
            env=env0,
        )

        tr1_proc = subprocess.Popen(
            tr0_cmd.strip().split(),
            stdout=subprocess.PIPE,
            stderr=tr1_pipe,
            env=env1,
        )
263 264 265 266 267 268 269 270

        tr0_out, tr0_err = tr0_proc.communicate()
        tr1_out, tr1_err = tr1_proc.communicate()
        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
        sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
        # close trainer file
        tr0_pipe.close()
        tr1_pipe.close()
271
        with open(path0, "r") as f:
272
            sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
273
        with open(path1, "r") as f:
274
            sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
        return (
            pickle.loads(tr0_out),
            pickle.loads(tr1_out),
            tr0_proc.pid,
            tr1_proc.pid,
        )

    def check_with_place(
        self,
        model_file,
        col_type,
        backend="nccl",
        path_id="0",
        static_mode="1",
        check_error_log=False,
        need_envs={},
        eager_mode=True,
        dtype=None,
    ):
294 295 296 297
        if backend == "nccl" or backend == "bkcl":
            with_gloo = '0'
        else:
            with_gloo = '1'
298
        required_envs = os.environ.copy()
299
        dtype = "float32" if dtype is None else dtype
300
        additional_envs = {
301
            "NCCL_P2P_DISABLE": "1",
L
lilong12 已提交
302
            "STATIC_MODE": static_mode,
L
lilong12 已提交
303
            "PADDLE_WITH_GLOO": with_gloo,
304
            "PADDLE_DISTRI_BACKEND": backend,
305
            "BACKEND": backend,
306
            "PATH_ID": path_id,
307
            "DTYPE": dtype,
308
        }
309
        required_envs.update(additional_envs)
310 311 312 313
        required_envs.update(need_envs)
        if check_error_log:
            required_envs["GLOG_v"] = "3"
            required_envs["GLOG_logtostderr"] = "1"
314
            required_envs["GLOO_LOG_LEVEL"] = "TRACE"
315

316 317
        if os.getenv('NVIDIA_TF32_OVERRIDE', '') is not None:
            required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv(
318 319
                'NVIDIA_TF32_OVERRIDE', ''
            )
320

321
        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
322 323
            model_file, required_envs
        )
324 325
        input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
        input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
326 327 328 329
        # cast bfloat16 to float32 for numeric comparison
        if dtype == "bfloat16":
            input1 = input1.astype("float32")
            input2 = input2.astype("float32")
330 331 332 333
        if col_type == "allgather":
            need_result = np.vstack((input1, input2))
            tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
            tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
334 335
            np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05)
            np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05)
336
        elif col_type == "allgather_object":
337 338 339
            need_result = [input1, input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
340 341
        elif col_type == "broadcast":
            need_result = input2
342 343
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
344 345 346 347
        elif col_type == "broadcast_object_list":
            need_result = [input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
348 349
        elif col_type == "reduce":
            need_result = input1 + input2
350 351 352 353 354 355 356
            # bfloat16 precision loss comes from truncating the last 16 bits of float32,
            # which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078
            if dtype == "bfloat16":
                rtol = 8e-03
            else:
                rtol = 1e-05
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=rtol)
357 358
        elif col_type == "scatter":
            need_result = input2
359 360
            need_result1 = need_result[0 : need_result.shape[0] // 2]
            need_result2 = need_result[need_result.shape[0] // 2 :]
361 362
            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
363 364 365 366 367 368
        elif col_type == "scatter_object_list":
            need_result = input2
            need_result1 = [need_result[0 : len(need_result) // 2]]
            need_result2 = [need_result[len(need_result) // 2 :]]
            self.assertEqual(need_result1, tr0_out)
            self.assertEqual(need_result2, tr1_out)
369 370
        elif col_type == "reduce_scatter":
            need_result = input1 + input2
371 372
            need_result1 = need_result[0 : need_result.shape[0] // 2]
            need_result2 = need_result[need_result.shape[0] // 2 :]
373 374 375 376 377 378
            if dtype == "bfloat16":
                rtol = 8e-03
            else:
                rtol = 1e-05
            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=rtol)
            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=rtol)
379 380
        elif col_type == "allreduce":
            need_result = input1 + input2
381 382 383 384 385 386
            if dtype == "bfloat16":
                rtol = 8e-03
                atol = 8e-03
            else:
                rtol = 1e-05
                atol = 1e-05
387 388 389 390 391 392
            np.testing.assert_allclose(
                tr0_out[0], need_result, rtol=rtol, atol=atol
            )
            np.testing.assert_allclose(
                tr1_out[0], need_result, rtol=rtol, atol=atol
            )
393 394 395
        elif col_type == "parallel_embedding":
            result_data = tr0_out[0]
            np.random.seed(2020)
396
            need_result = np.random.rand(12, 8)
397 398 399
            for i in range(result_data.shape[0]):
                for j in range(result_data.shape[1]):
                    data = result_data[i][j]
400
                    np.testing.assert_allclose(
401 402
                        tr0_out[1][i][j], need_result[data], atol=1e-08
                    )
403 404 405 406 407
        elif col_type == "row_parallel_linear":
            result_data = tr0_out[0]
            np.random.seed(2020)
            weight = np.random.rand(1000, 16)
            need_result = np.matmul(input1, weight)
408 409 410
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
411 412 413 414 415
        elif col_type == "column_parallel_linear":
            result_data = tr0_out[0]
            np.random.seed(2020)
            weight = np.random.rand(1000, 16)
            need_result = np.matmul(input1, weight)
416 417 418
            np.testing.assert_allclose(
                result_data, need_result, rtol=1e-05, atol=1e-05
            )
L
lilong12 已提交
419
        elif col_type == "alltoall":
420 421 422 423 424 425 426 427 428 429 430 431
            need_result1 = np.vstack(
                (
                    input1[0 : input1.shape[0] // 2, :],
                    input2[0 : input2.shape[0] // 2, :],
                )
            )
            need_result2 = np.vstack(
                (
                    input1[input1.shape[0] // 2 :, :],
                    input2[input2.shape[0] // 2 :, :],
                )
            )
L
lilong12 已提交
432 433
            tr0_out = np.vstack(tr0_out)
            tr1_out = np.vstack(tr1_out)
434 435 436 437 438 439
            np.testing.assert_allclose(
                tr0_out, need_result1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out, need_result2, rtol=1e-05, atol=1e-05
            )
L
lilong12 已提交
440 441
        elif col_type == "sendrecv":
            result_data = tr1_out[0]
442 443 444
            np.testing.assert_allclose(
                input1, result_data, rtol=1e-05, atol=1e-05
            )
445 446 447 448 449 450 451 452
        elif col_type == "global_gather":
            in_feat = 2
            n_expert = 2
            world_size = 2
            tot_expert = n_expert * world_size

            np.random.seed(pid0)
            local_expert_count1 = np.random.randint(
453 454
                1, 4, size=tot_expert
            ).astype("int")
455 456 457 458 459 460 461
            expert_ptr1 = np.ones(tot_expert, dtype=np.int32)
            expert_ptr1[0] = 0
            for i in range(1, tot_expert):
                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]

            np.random.seed(pid1)
            local_expert_count2 = np.random.randint(
462 463
                1, 4, size=tot_expert
            ).astype("int")
464 465 466 467 468 469 470 471 472 473 474 475 476 477
            expert_ptr2 = np.ones(tot_expert, dtype=np.int32)
            expert_ptr2[0] = 0
            for i in range(1, tot_expert):
                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]

            global_expert_count1 = np.zeros(tot_expert).astype("int")
            global_expert_count2 = np.zeros(tot_expert).astype("int")
            global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert]
            global_expert_count1[n_expert:] = local_expert_count2[0:n_expert]
            global_expert_count2[0:n_expert] = local_expert_count1[n_expert:]
            global_expert_count2[n_expert:] = local_expert_count2[n_expert:]

            np.random.seed(pid0)
            fwd_expert_count = sum(global_expert_count1).astype("int")
478 479 480
            local_input_buf1 = np.random.rand(fwd_expert_count, in_feat).astype(
                "float32"
            )
481 482
            np.random.seed(pid1)
            fwd_expert_count = sum(global_expert_count2).astype("int")
483 484 485
            local_input_buf2 = np.random.rand(fwd_expert_count, in_feat).astype(
                "float32"
            )
486 487 488 489 490 491 492 493 494
            output1 = [[], [], [], []]
            output2 = [[], [], [], []]
            send_ptr1 = 0
            send_ptr2 = 0

            for i in range(n_expert):
                for j in range(world_size):
                    idx = j * n_expert + i
                    if j == 0:
495 496 497 498 499 500
                        output1_part1 = local_input_buf1[
                            send_ptr1 : send_ptr1 + global_expert_count1[idx], :
                        ]
                        output1_part2 = local_input_buf2[
                            send_ptr2 : send_ptr2 + global_expert_count2[idx], :
                        ]
501 502 503
                        output1[i].extend(output1_part1)
                        output1[i + n_expert].extend(output1_part2)
                    else:
504 505 506 507 508 509
                        output2_part1 = local_input_buf1[
                            send_ptr1 : send_ptr1 + global_expert_count1[idx]
                        ]
                        output2_part2 = local_input_buf2[
                            send_ptr2 : send_ptr2 + global_expert_count2[idx]
                        ]
510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528
                        output2[i].extend(output2_part1)
                        output2[i + n_expert].extend(output2_part2)
                    send_ptr1 = send_ptr1 + global_expert_count1[idx]
                    send_ptr2 = send_ptr2 + global_expert_count2[idx]
            result1 = []
            result2 = []
            for i in range(tot_expert):
                for arr in output1[i]:
                    if arr == []:
                        continue
                    result1.append(arr)
            for i in range(tot_expert):
                for arr in output2[i]:
                    if arr == []:
                        continue
                    result2.append(arr)
            if result1 == []:
                output1 = np.array([])
            else:
529
                output1 = np.concatenate(result1, axis=0).reshape(
530 531
                    sum(local_expert_count1), in_feat
                )
532 533 534
            if result2 == []:
                output2 = np.array([])
            else:
535
                output2 = np.concatenate(result2, axis=0).reshape(
536 537
                    sum(local_expert_count2), in_feat
                )
538 539 540 541 542 543 544

            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                tr0_out[0] = np.array([])

            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
                tr1_out[0] = np.array([])

545 546 547 548 549 550
            np.testing.assert_allclose(
                tr0_out[0], output1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[0], output2, rtol=1e-05, atol=1e-05
            )
551
            if static_mode == 0:
552 553 554 555 556 557
                np.testing.assert_allclose(
                    tr0_out[1], 2 * local_input_buf1, rtol=1e-05, atol=1e-05
                )
                np.testing.assert_allclose(
                    tr1_out[1], 2 * local_input_buf2, rtol=1e-05, atol=1e-05
                )
558 559 560 561 562

        elif col_type == "global_scatter":
            np.random.seed(pid0)
            local_expert_count1 = np.random.randint(1, 4, size=4).astype("int")
            fwd_expert_count = sum(local_expert_count1)
563 564 565
            local_input_buf1 = np.random.rand(fwd_expert_count, 2).astype(
                "float32"
            )
566 567 568 569 570 571 572
            expert_ptr1 = np.ones(4, dtype=np.int32)
            expert_ptr1[0] = 0
            for i in range(1, 4):
                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]
            np.random.seed(pid1)
            local_expert_count2 = np.random.randint(1, 4, size=4).astype("int")
            fwd_expert_count = sum(local_expert_count2)
573 574 575
            local_input_buf2 = np.random.rand(fwd_expert_count, 2).astype(
                "float32"
            )
576 577 578 579 580 581 582 583 584 585 586 587
            expert_ptr2 = np.ones(4, dtype=np.int32)
            expert_ptr2[0] = 0
            for i in range(1, 4):
                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]

            output1 = []
            output2 = []
            for i in range(2):
                for j in range(2):
                    idx = j * 2 + i
                    if j == 0:
                        # send data to 0 card
588 589 590 591 592 593 594 595 596 597 598 599
                        output1.append(
                            local_input_buf1[
                                expert_ptr1[idx] : expert_ptr1[idx]
                                + local_expert_count1[idx]
                            ]
                        )
                        output1.append(
                            local_input_buf2[
                                expert_ptr2[idx] : expert_ptr2[idx]
                                + local_expert_count2[idx]
                            ]
                        )
600
                    else:
601 602 603 604 605 606 607 608 609 610 611 612
                        output2.append(
                            local_input_buf1[
                                expert_ptr1[idx] : expert_ptr1[idx]
                                + local_expert_count1[idx]
                            ]
                        )
                        output2.append(
                            local_input_buf2[
                                expert_ptr2[idx] : expert_ptr2[idx]
                                + local_expert_count2[idx]
                            ]
                        )
613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
            if output1 == []:
                output1 = np.array([])
            else:
                output1 = np.concatenate(output1)
            if output2 == []:
                output2 = np.array([])
            else:
                output2 = np.concatenate(output2)

            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                tr0_out[0] = np.array([])

            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
                tr1_out[0] = np.array([])

628 629 630 631 632 633
            np.testing.assert_allclose(
                tr0_out[0], output1, rtol=1e-05, atol=1e-05
            )
            np.testing.assert_allclose(
                tr1_out[0], output2, rtol=1e-05, atol=1e-05
            )
634
            if static_mode == 0:
635 636 637 638 639 640
                np.testing.assert_allclose(
                    tr0_out[1], 2 * local_input_buf1, rtol=1e-05, atol=1e-05
                )
                np.testing.assert_allclose(
                    tr1_out[1], 2 * local_input_buf2, rtol=1e-05, atol=1e-05
                )
641 642
        else:
            pass