test_auto_parallel_partitioner.py 48.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
25
from paddle.distributed.fleet import auto
26
from paddle.distributed.auto_parallel.completion import Completer
27
from paddle.distributed.auto_parallel.dist_context import DistributedContext
28 29 30
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
31
from paddle.distributed.auto_parallel.process_group import new_process_group
32 33

# The programs below are built and partitioned in static-graph mode.
paddle.enable_static()

# Assigned by each test case before get_programs() is called and read by the
# layer/forward builders below to decide how tensors are annotated.
_global_parallel_strategy = _global_process_mesh = None

def get_programs(annotated_func, rank_id=3):
    """Build a serial program with ``annotated_func`` and partition it for one rank.

    The serial forward program is completed with ``Completer`` and then split
    by ``Partitioner`` for ``rank_id``. The process mesh is taken from the
    module-level ``_global_process_mesh``, which each test sets beforehand.

    Args:
        annotated_func: callable ``(train_program, start_program)`` returning
            the (possibly new) ``(train_program, start_program)`` pair after
            building the annotated network into them.
        rank_id: rank to partition for. Defaults to 3, the rank that
            ``initialization_check`` computes communication groups for.

    Returns:
        tuple: ``(complete_train_program, start_program, dist_main_prog,
        dist_startup_prog, dist_context)``.
    """
    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    # Only read here, so no ``global`` declaration is needed.
    dist_context.process_mesh = _global_process_mesh
    train_program, start_program = annotated_func(train_program, start_program)

    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    )
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    partitioner = Partitioner(dist_context, rank_id)
    (
        dist_main_prog,
        dist_startup_prog,
        _,
    ) = partitioner.partition(complete_train_program, start_program, [])

    return (
        complete_train_program,
        start_program,
        dist_main_prog,
        dist_startup_prog,
        dist_context,
    )


def is_all_parameters_shape_equal(prog1, prog2):
    """Return True if both programs hold parameters with identical shapes.

    Each program's parameters are ordered by name and their shapes are
    compared position by position; the parameter names themselves are not
    compared. Programs with different parameter counts are never equal.
    """

    def _shapes_sorted_by_name(prog):
        ordered = sorted(prog.all_parameters(), key=lambda param: param.name)
        return [param.shape for param in ordered]

    return _shapes_sorted_by_name(prog1) == _shapes_sorted_by_name(prog2)


def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
    """Return True if each ``varnames1[i]`` in ``prog1`` is one ``nsplit``-th
    of ``varnames2[i]`` in ``prog2`` along ``axis``.

    The sizes are compared as ``size1 == size2 // nsplit`` (floor division).
    Checking stops at the first mismatch, and only the first
    ``len(varnames1)`` entries of ``varnames2`` are consulted.
    """

    def _dim(prog, name):
        return prog.global_block().var(name).shape[axis]

    return not any(
        _dim(prog1, name) != _dim(prog2, varnames2[pos]) // nsplit
        for pos, name in enumerate(varnames1)
    )


def initialization_check(
    mode,
    dist_context,
    dist_startup_prog,
    serial_startup_prog,
    var_need_broadcast,
    process_mesh,
    mp_parallel_axis,
    dp_parallel_axis,
    rank_id=3,
):
    """Check the ``c_broadcast`` ops in a partitioned startup program.

    Args:
        mode: parallel strategy string; the substrings ``"mp"`` and ``"dp"``
            select which checks run.
        dist_context: not used by the checks.
        dist_startup_prog: partitioned startup program to inspect.
        serial_startup_prog: serial startup program; its parameter count is
            the expected number of broadcasts on the dp ring.
        var_need_broadcast: names expected to be broadcast on the mp ring
            (any order).
        process_mesh: mesh whose ``processes``/``topology`` define the groups.
        mp_parallel_axis: mesh axis of the mp group (used when "mp" in mode).
        dp_parallel_axis: mesh axis of the dp group (used when "dp" in mode).
        rank_id: rank whose communication groups are looked up.

    Returns:
        True if every enabled check passes, otherwise False.
    """
    broadcast_ops = [
        op
        for op in dist_startup_prog.global_block().ops
        if op.type == "c_broadcast"
    ]

    def _ring_id(parallel_axis):
        # Ring id of the group along ``parallel_axis`` for ``rank_id``.
        group_ranks = _get_comm_group(
            process_mesh.processes,
            process_mesh.topology,
            parallel_axis,
            rank_id,
        )
        return new_process_group(group_ranks).id

    def _ops_on_ring(ring_id):
        return [
            op for op in broadcast_ops if op.desc.attr("ring_id") == ring_id
        ]

    if 'mp' in mode:
        mp_ops = _ops_on_ring(_ring_id(mp_parallel_axis))
        broadcast_varnames = sorted(
            op.desc.output_arg_names()[0] for op in mp_ops
        )
        # Order-insensitive, so callers need not pre-sort the expected names.
        if broadcast_varnames != sorted(var_need_broadcast):
            return False

    if 'dp' in mode:
        # Every serial parameter is expected to be broadcast on the dp ring.
        nbroadcast_dp = len(_ops_on_ring(_ring_id(dp_parallel_axis)))
        if len(serial_startup_prog.all_parameters()) != nbroadcast_dp:
            return False

        if 'mp' in mode:
            # With both axes, all broadcasts are the mp ones plus the dp ones.
            if len(var_need_broadcast) + nbroadcast_dp != len(broadcast_ops):
                return False

    return True


def get_input_var_dist_attr(op, main_program, dist_context):
    """Look up the tensor dist_attr of ``op``'s first input argument.

    Args:
        op: op whose first input argument name is resolved.
        main_program: program whose global block owns that variable.
        dist_context: context holding the tensor dist attrs.

    Returns:
        The result of ``dist_context.get_tensor_dist_attr_for_program`` for
        that variable.
    """
    input_names = op.desc.input_arg_names()
    block = main_program.global_block()
    return dist_context.get_tensor_dist_attr_for_program(
        block.var(input_names[0])
    )


def get_output_var_dist_attr(op, main_program, dist_context):
    """Look up the tensor dist_attr of ``op``'s first output argument.

    Args:
        op: op whose first output argument name is resolved.
        main_program: program whose global block owns that variable.
        dist_context: context holding the tensor dist attrs.

    Returns:
        The result of ``dist_context.get_tensor_dist_attr_for_program`` for
        that variable.
    """
    output_names = op.desc.output_arg_names()
    block = main_program.global_block()
    return dist_context.get_tensor_dist_attr_for_program(
        block.var(output_names[0])
    )


def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
    """Return True when both tensor dist attrs agree on ``process_mesh`` and
    ``dims_mapping``; ``dims_mapping`` is only compared if the meshes match.
    """
    mesh_differs = serial_dist_attr.process_mesh != dist_attr.process_mesh
    return not (
        mesh_differs or serial_dist_attr.dims_mapping != dist_attr.dims_mapping
    )


def check_equal_dist_op_attr(
    dist_context, dist_main_prog, serial_op, dist_ops, dist_op_idx
):
    """Return True if the selected dist ops carry consistent dist attrs.

    For each index in ``dist_op_idx``, the op ``dist_ops[index]`` must:

    * have op-level input/output dims_mappings equal to the dims_mapping
      recorded for the same-named tensor of ``dist_main_prog``;
    * share ``process_mesh`` and ``impl_idx`` with ``serial_op``;
    * be a different op than ``serial_op`` (different ``desc.id()``).

    Every index is examined even after a mismatch has been found.
    """
    serial_attr = dist_context.get_op_dist_attr_for_program(serial_op)
    ref_mesh = serial_attr.process_mesh
    ref_impl = serial_attr.impl_idx

    consistent = True
    for index in dist_op_idx:
        dist_op = dist_ops[index]
        op_attr = dist_context.get_op_dist_attr_for_program(dist_op)

        # op-level mapping of each input must match the tensor's own mapping
        for name in dist_op.desc.input_arg_names():
            var = dist_main_prog.global_block().var(name)
            tensor_mapping = dist_context.get_tensor_dist_attr_for_program(
                var
            ).dims_mapping
            if tensor_mapping != op_attr.get_input_dims_mapping(name):
                consistent = False

        # same for each output
        for name in dist_op.desc.output_arg_names():
            var = dist_main_prog.global_block().var(name)
            tensor_mapping = dist_context.get_tensor_dist_attr_for_program(
                var
            ).dims_mapping
            if tensor_mapping != op_attr.get_output_dims_mapping(name):
                consistent = False

        op_mesh = op_attr.process_mesh
        op_impl = op_attr.impl_idx
        is_same_op = serial_op.desc.id() == dist_op.desc.id()
        if is_same_op or ref_mesh != op_mesh or ref_impl != op_impl:
            consistent = False

    return consistent


def distributed_attr_check_for_dist_op(
    serial_main_prog, dist_main_prog, dist_context, serial_op_idx, dist_op_idx
):
    """Compare the dist attrs of partitioned ops with their serial ops.

    ``serial_op_idx[i]`` indexes an op of ``serial_main_prog`` and
    ``dist_op_idx[i]`` lists the indices of the ops of ``dist_main_prog``
    it was partitioned into. For each pair:

    * if the first dist op is ``c_identity``, the serial op's first input
      dist_attr must equal that ``c_identity``'s output dist_attr;
      otherwise the serial op's first output dist_attr must equal the first
      dist op's output dist_attr;
    * ``check_equal_dist_op_attr`` must hold for the listed dist ops.

    Returns:
        True only if every check passes for every pair.
    """
    serial_ops = serial_main_prog.global_block().ops
    dist_ops = dist_main_prog.global_block().ops

    for i, serial_index in enumerate(serial_op_idx):
        serial_op = serial_ops[serial_index]
        dist_indices = dist_op_idx[i]
        first_dist_op = dist_ops[dist_indices[0]]

        if first_dist_op.type == "c_identity":
            # serial op input's dist_attr vs c_identity output's (new var)
            serial_attr = get_input_var_dist_attr(
                serial_op, serial_main_prog, dist_context
            )
        else:
            # serial op output's dist_attr vs dist op output's (new var)
            serial_attr = get_output_var_dist_attr(
                serial_op, serial_main_prog, dist_context
            )
        dist_attr = get_output_var_dist_attr(
            first_dist_op, dist_main_prog, dist_context
        )

        # Every check must gate the result; a later check must not overwrite
        # an earlier failure.
        if not check_equal_var_dist_attr(serial_attr, dist_attr):
            return False
        if not check_equal_dist_op_attr(
            dist_context, dist_main_prog, serial_op, dist_ops, dist_indices
        ):
            return False

    return True


def distributed_attr_check_for_program(dist_main_prog, dist_context):
    """Return True if every var and op in every block of ``dist_main_prog``
    has a (non-None) dist attr registered in ``dist_context``.

    All lookups are performed even after a missing attr has been found.
    """
    missing = False
    for blk in dist_main_prog.blocks:
        tensor_attrs = [
            dist_context.get_tensor_dist_attr_for_program(v)
            for v in blk.vars.values()
        ]
        op_attrs = [
            dist_context.get_op_dist_attr_for_program(o) for o in blk.ops
        ]
        if any(a is None for a in tensor_attrs) or any(
            a is None for a in op_attrs
        ):
            missing = True
    return not missing


class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GELU -> Linear -> Dropout block.

    ``forward`` annotates both Linear weights with ``auto.shard_tensor``
    according to the module-level ``_global_parallel_strategy``.
    """

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        # One ParamAttr (Normal(0, initializer_range)) is shared by both Linears.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        # Creation order is kept: linear0, linear1, then norm and dropout.
        self.linear0 = nn.Linear(
            hidden_size, intermediate_size, weight_attr, bias_attr=None
        )
        self.linear1 = nn.Linear(
            intermediate_size, hidden_size, weight_attr, bias_attr=None
        )
        self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        if _global_parallel_strategy in ("mp", "dp_mp"):
            # linear0 weight sharded on axis 1, linear1 weight on axis 0
            specs = ([None, "mp"], ["mp", None])
        else:
            # fully replicated
            specs = ([None, None], [None, None])
        for layer, spec in zip((self.linear0, self.linear1), specs):
            auto.shard_tensor(
                layer.weight,
                process_mesh=_global_process_mesh,
                shard_spec=spec,
            )

        hidden = self.norm(input)
        hidden = self.linear0(hidden)
        hidden = F.gelu(hidden, approximate=True)
        hidden = self.linear1(hidden)
        return self.dropout(hidden)


def mlp_pretrain_forward(train_program, start_program):
    """Build the annotated MLP forward graph into the given programs.

    Args:
        train_program: main Program that receives the forward ops.
        start_program: startup Program paired with ``train_program``.

    Returns:
        The same ``(train_program, start_program)`` pair.
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024

    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        x = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        if _global_parallel_strategy in ("dp", "dp_mp"):
            # split axis 0 (batch_size) over the "dp" mesh dimension
            auto.shard_tensor(
                x,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        mlp(x)

    return train_program, start_program


class TestMLPAutoPartitioner(unittest.TestCase):
    """Partitioner tests for the MLP block under "dp", "mp" and "dp_mp"."""

    def _partition(self, strategy, mesh):
        """Install ``strategy`` and ``mesh`` as the module globals, then build
        and partition the MLP with ``get_programs``.

        Returns:
            The tuple returned by ``get_programs``: ``(serial_main_prog,
            serial_startup_prog, dist_main_prog, dist_startup_prog,
            dist_context)``.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh
        return get_programs(mlp_pretrain_forward)

    def _check_mp_partition(self, programs, mp_parallel_axis, dp_parallel_axis):
        """Run the checks shared by the "mp" and "dp_mp" cases.

        Args:
            programs: tuple returned by ``_partition``.
            mp_parallel_axis: mesh axis of the "mp" dimension.
            dp_parallel_axis: mesh axis of the "dp" dimension, or None.
        """
        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = programs

        # params should be partitioned 4 ways along the "mp" mesh dimension
        nrank = 4
        # (parameter names, split axis, expected number of shards)
        split_cases = [
            # col parallel
            (['linear_0.w_0'], 1, nrank),
            (['linear_0.b_0'], 0, nrank),
            # row parallel (the bias is not split)
            (['linear_1.w_0'], 0, nrank),
            (['linear_1.b_0'], 0, 1),
        ]
        for weights, axis, nsplit in split_cases:
            self.assertTrue(
                check_tensor_split(
                    dist_main_prog,
                    weights,
                    serial_main_prog,
                    weights,
                    axis,
                    nsplit,
                ),
                msg=f"unexpected split of {weights} along axis {axis}",
            )

        # row and col allreduce
        ref_ops = [
            'layer_norm',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'gelu',
            'matmul_v2',
            'c_allreduce_sum',
            'elementwise_add',
            'dropout',
        ]
        self.assertEqual(
            [op.type for op in dist_main_prog.global_block().ops], ref_ops
        )

        # parameter initialization
        var_need_broadcast = sorted(
            ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']
        )
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=mp_parallel_axis,
                dp_parallel_axis=dp_parallel_axis,
            )
        )

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context)
        )
        # check distributed attr for dist op
        self.assertTrue(
            distributed_attr_check_for_dist_op(
                serial_main_prog,
                dist_main_prog,
                dist_context,
                serial_op_idx=[1, 4],
                dist_op_idx=[[1, 2], [5, 6]],
            )
        )

    def test_mlp_dp(self):
        """Data parallel: parameters and the op sequence stay unchanged."""
        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = self._partition(
            "dp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["dp"])
        )

        # parameter should not be partitioned
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)
        )
        self.assertTrue(
            is_all_parameters_shape_equal(
                serial_startup_prog, dist_startup_prog
            )
        )

        # op in main prog should be the same
        self.assertEqual(
            [op.type for op in serial_main_prog.global_block().ops],
            [op.type for op in dist_main_prog.global_block().ops],
        )

        # parameter initialization: nothing is broadcast on an mp ring
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                [],
                _global_process_mesh,
                mp_parallel_axis=None,
                dp_parallel_axis=0,
            )
        )

    def test_mlp_mp(self):
        """Tensor parallel on a 1-D mesh of 4 ranks."""
        programs = self._partition(
            "mp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["mp"])
        )
        self._check_mp_partition(
            programs, mp_parallel_axis=0, dp_parallel_axis=None
        )

    def test_mlp_dp_mp(self):
        """Data + tensor parallel on a 2x4 mesh."""
        programs = self._partition(
            "dp_mp",
            auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
            ),
        )
        self._check_mp_partition(
            programs, mp_parallel_axis=1, dp_parallel_axis=0
        )

class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used by the partitioner tests.

    ``forward`` annotates its input (for "dp"/"dp_mp") and the q/k/v/out
    projection weights (for "mp"/"dp_mp") with ``auto.shard_tensor``. The
    order in which ops are created in ``forward`` is relied on by the
    tests' expected op lists, so it must not change.
    """

    def __init__(
        self,
        hidden_size=1024,
        sequence_len=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        # NOTE: intermediate_size is accepted but not used by this layer.
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = hidden_size
        self.kdim = self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        # Creation order q, k, v, out is kept; the tests refer to these
        # parameters as linear_0 .. linear_3 in that order.
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=None
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=None
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=None
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=None
        )

    def _split_heads(self, x):
        """Reshape to ``[0, 0, num_heads, head_dim]``, then transpose with
        perm ``[0, 2, 1, 3]`` (emits a reshape2 then a transpose2 op)."""
        x = tensor.reshape(x=x, shape=[0, 0, self.num_heads, self.head_dim])
        return tensor.transpose(x=x, perm=[0, 2, 1, 3])

    def forward(self, input):
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        query = self._split_heads(self.q_proj(input))
        key = self.k_proj(input)
        value = self.v_proj(input)

        if _global_parallel_strategy in ("mp", "dp_mp"):
            # q/k/v weights are sharded on axis 1 over "mp"
            for proj in (self.q_proj, self.k_proj, self.v_proj):
                auto.shard_tensor(
                    proj.weight,
                    process_mesh=_global_process_mesh,
                    shard_spec=[None, "mp"],
                )

        key = self._split_heads(key)
        value = self._split_heads(value)

        # scaled dot-product attention
        scores = layers.matmul(
            x=query, y=key, transpose_y=True, alpha=self.head_dim**-0.5
        )
        if self.attn_mask is not None:
            scores = scores + self.attn_mask

        probs = F.softmax(scores)
        if self.dropout_ratio:
            probs = F.dropout(
                probs,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        context = tensor.matmul(probs, value)

        # merge the heads back into the last dimension
        context = tensor.transpose(context, perm=[0, 2, 1, 3])
        context = tensor.reshape(
            x=context, shape=[0, 0, context.shape[2] * context.shape[3]]
        )

        # output projection
        projected = self.out_proj(context)

        if _global_parallel_strategy in ("mp", "dp_mp"):
            # output projection weight is sharded on axis 0 over "mp"
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        return projected


def attn_pretrain_forward(train_program, start_program):
    """Build the annotated attention forward graph into the given programs.

    Args:
        train_program: main Program that receives the forward ops.
        start_program: startup Program paired with ``train_program``.

    Returns:
        The same ``(train_program, start_program)`` pair.
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024

    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        attention = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        attention(query)

    return train_program, start_program


class TestAttentionAutoPartitioner(unittest.TestCase):
    def test_attn_dp(self):
        """Data parallel: parameters and the op sequence stay unchanged."""
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["dp"]
        )

        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = get_programs(attn_pretrain_forward)

        # parameter should not be partitioned
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)
        )
        self.assertTrue(
            is_all_parameters_shape_equal(
                serial_startup_prog, dist_startup_prog
            )
        )

        # op in main prog should be the same; assertEqual shows the diff
        self.assertEqual(
            [op.type for op in serial_main_prog.global_block().ops],
            [op.type for op in dist_main_prog.global_block().ops],
        )

        # parameter initialization: nothing is broadcast on an mp ring
        var_need_broadcast = []
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=None,
                dp_parallel_axis=0,
            )
        )

    def test_attn_mp(self):
        """Tensor parallel on a 1-D mesh of 4 ranks."""
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["mp"]
        )

        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = get_programs(attn_pretrain_forward)

        # params should be partitioned 4 ways along the "mp" mesh dimension
        nrank = 4
        # (parameter names, split axis, expected number of shards)
        split_cases = [
            # col parallel: q/k/v projections
            (['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'], 1, nrank),
            (['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'], 0, nrank),
            # row parallel: output projection (the bias is not split)
            (['linear_3.w_0'], 0, nrank),
            (['linear_3.b_0'], 0, 1),
        ]
        for weights, axis, nsplit in split_cases:
            self.assertTrue(
                check_tensor_split(
                    dist_main_prog,
                    weights,
                    serial_main_prog,
                    weights,
                    axis,
                    nsplit,
                ),
                msg=f"unexpected split of {weights} along axis {axis}",
            )

        # row and col allreduce
        ref_ops = [
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'reshape2',
            'transpose2',
            'matmul',
            'softmax',
            'dropout',
            'matmul_v2',
            'transpose2',
            'reshape2',
            'matmul_v2',
            'c_allreduce_sum',
            'elementwise_add',
        ]
        self.assertEqual(
            [op.type for op in dist_main_prog.global_block().ops], ref_ops
        )

        # parameter initialization
        var_need_broadcast = ['linear_3.b_0']
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=0,
                dp_parallel_axis=None,
            )
        )

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context)
        )
        # check distributed attr for dist op
        serial_op_idx = [0, 4, 6, 18]
        dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(
                serial_main_prog,
                dist_main_prog,
                dist_context,
                serial_op_idx,
                dist_op_idx,
            )
        )

915
    def test_attn_dp_mp(self):
        """Partition the attention block on a 2x4 ("dp", "mp") mesh.

        Checks the parameter splits, the partitioned op sequence, the
        startup-program initialization and the distributed attributes
        attached to the partitioned program.
        """
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
        )

        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = get_programs(attn_pretrain_forward)

        # Parameters are partitioned across the 4 ranks of the "mp" mesh axis.
        nrank = 4
        # Column parallel: weights split on axis 1, biases on axis 0.
        weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0']
        self.assertTrue(
            check_tensor_split(
                dist_main_prog, weights, serial_main_prog, weights, 1, nrank
            )
        )
        weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0']
        self.assertTrue(
            check_tensor_split(
                dist_main_prog, weights, serial_main_prog, weights, 0, nrank
            )
        )
        # Row parallel: weight split on axis 0, bias left whole.
        weights = ['linear_3.w_0']
        self.assertTrue(
            check_tensor_split(
                dist_main_prog, weights, serial_main_prog, weights, 0, nrank
            )
        )
        weights = ['linear_3.b_0']
        self.assertTrue(
            check_tensor_split(
                dist_main_prog, weights, serial_main_prog, weights, 0, 1
            )
        )

        # Row and column parallelism insert c_identity / c_allreduce_sum ops.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'reshape2',
            'transpose2',
            'matmul',
            'softmax',
            'dropout',
            'matmul_v2',
            'transpose2',
            'reshape2',
            'matmul_v2',
            'c_allreduce_sum',
            'elementwise_add',
        ]
        # assertEqual reports which ops differ on failure.
        self.assertEqual(dist_ops, ref_ops)

        # Parameter initialization: the unsplit row-parallel bias is broadcast.
        var_need_broadcast = ['linear_3.b_0']
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=1,
                dp_parallel_axis=0,
            )
        )

        # Every var and op in the dist main program must carry a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context)
        )
        # Check the distributed attributes of each partitioned op; each serial
        # op index maps to the pair of dist op indices it was expanded into.
        serial_op_idx = [0, 4, 6, 18]
        dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(
                serial_main_prog,
                dist_main_prog,
                dist_context,
                serial_op_idx,
                dist_op_idx,
            )
        )

class DecoderLayer(nn.Layer):
    """Pre-norm decoder layer used to exercise the auto-parallel partitioner.

    The layer sums word and position embeddings, applies multi-head
    self-attention and then a two-layer GELU MLP. Each sub-block is preceded
    by ``self.norm`` and wrapped in a residual connection.

    ``forward`` reads the module-level ``_global_parallel_strategy`` and
    ``_global_process_mesh`` to annotate tensors with ``auto.shard_tensor``:
    "dp"/"dp_mp" shard ``input_ids`` on the "dp" axis, and "mp"/"dp_mp"
    shard the embedding, attention and MLP weights on the "mp" axis.
    """

    def __init__(
        self,
        vocab_size=32768,
        hidden_size=1024,
        sequence_len=512,
        max_position_embeddings=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the embeddings, projections, MLP, norm and dropout layers.

        Args:
            vocab_size: number of rows in the word-embedding table.
            hidden_size: model width, also used as the attention embed dim.
            sequence_len: stored on the layer; not otherwise used here.
            max_position_embeddings: number of rows in the position table.
            intermediate_size: output width of the first MLP projection
                (``linear0``) and input width of the second (``linear1``).
            num_heads: number of attention heads; must divide ``hidden_size``.
            dropout_ratio: dropout probability used by every dropout site.
            initializer_range: std of the Normal initializer used for the
                embeddings and all Linear weights.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )

        # One unnamed ParamAttr is shared by every Linear layer below. The
        # parameters still get per-layer names (linear_N.w_0 / linear_N.b_0),
        # which the partitioner tests look up.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None

        # Attention projections (linear_0 .. linear_3).
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

        # MLP (linear_4, linear_5). Honor the caller's intermediate_size
        # instead of silently recomputing it from hidden_size.
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )

        # A single LayerNorm is reused for both pre-norm sites in forward().
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Run embeddings, self-attention and the MLP with residuals.

        Args:
            input_ids: int64 token ids; the tests feed shape [batch, seq].
            position_ids: int64 position ids with the same shape.

        Returns:
            Tensor: ``residual + dropout3(MLP(norm(residual)))`` where
            ``residual`` is the attention output added to the embeddings.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            # Data parallel: shard dim 0 of the ids (the batch in these tests).
            auto.shard_tensor(
                input_ids,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None],
            )

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # Split the vocabulary rows of the embedding table over "mp".
            auto.shard_tensor(
                self.word_embeddings.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # Column parallel q/k/v: split the output features over "mp".
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
        )

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # Row parallel output projection: split the input features.
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )
        else:
            # Otherwise annotate the weight as fully replicated.
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, None],
            )

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm (same LayerNorm instance as the attention block)
        out0 = self.norm(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # Column parallel first MLP layer, row parallel second one.
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Record one DecoderLayer forward pass into the given static programs.

    Args:
        train_program: main ``static.Program`` that receives the forward ops.
        start_program: startup ``static.Program`` used alongside it.

    Returns:
        tuple: ``(train_program, start_program)``, now populated.
    """
    batch_size, hidden_size, sequence_len = 4, 1024, 512

    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        layer = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Only the ops recorded into the programs matter; the output tensor
        # itself is not needed.
        layer(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerPartitioner(unittest.TestCase):
    """Partitioner tests for the program built by ``decoder_pretrain_forward``."""

    # Parameter names created by DecoderLayer, grouped by expected split.
    # Column parallel: weights split on axis 1, biases on axis 0.
    _COL_PARALLEL_WEIGHTS = (
        'linear_0.w_0',
        'linear_1.w_0',
        'linear_2.w_0',
        'linear_4.w_0',
    )
    _COL_PARALLEL_BIASES = (
        'linear_0.b_0',
        'linear_1.b_0',
        'linear_2.b_0',
        'linear_4.b_0',
    )
    # Row parallel: weights split on axis 0.
    _ROW_PARALLEL_WEIGHTS = ('word_embeddings', 'linear_3.w_0', 'linear_5.w_0')
    # Expected to stay whole (a single split).
    _REPLICATED_PARAMS = (
        'linear_3.b_0',
        'pos_embeddings',
        'layer_norm_0.b_0',
        'layer_norm_0.w_0',
        'linear_5.b_0',
    )

    def _check_param_splits(self, dist_main_prog, serial_main_prog, nrank):
        """Assert how every decoder parameter is split into ``nrank`` parts.

        Column-parallel weights and biases and row-parallel weights must be
        split ``nrank`` ways along their axis; the replicated parameters must
        stay unsplit.
        """
        expectations = (
            (self._COL_PARALLEL_WEIGHTS, 1, nrank),
            (self._COL_PARALLEL_BIASES, 0, nrank),
            (self._ROW_PARALLEL_WEIGHTS, 0, nrank),
            (self._REPLICATED_PARAMS, 0, 1),
        )
        for names, axis, nsplit in expectations:
            weights = list(names)
            self.assertTrue(
                check_tensor_split(
                    dist_main_prog,
                    weights,
                    serial_main_prog,
                    weights,
                    axis,
                    nsplit,
                ),
                msg=f"unexpected split of {weights}: axis={axis}, n={nsplit}",
            )

    def test_decoder_dp_mp(self):
        """Partition the decoder on a 2x4 ("dp", "mp") mesh."""
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
        )
        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = get_programs(decoder_pretrain_forward)

        # Parameters are partitioned across the 4 ranks of the "mp" axis.
        self._check_param_splits(dist_main_prog, serial_main_prog, nrank=4)

        # Row and column parallelism insert c_identity / c_allreduce_sum ops.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_embedding',
            'c_allreduce_sum',
            'lookup_table_v2',
            'elementwise_add',
            'dropout',
            'layer_norm',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'reshape2',
            'transpose2',
            'matmul',
            'softmax',
            'dropout',
            'matmul_v2',
            'transpose2',
            'reshape2',
            'matmul_v2',
            'c_allreduce_sum',
            'elementwise_add',
            'dropout',
            'elementwise_add',
            'layer_norm',
            'c_identity',
            'matmul_v2',
            'elementwise_add',
            'gelu',
            'matmul_v2',
            'c_allreduce_sum',
            'elementwise_add',
            'dropout',
            'elementwise_add',
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Parameter initialization: unsplit parameters are broadcast.
        var_need_broadcast = sorted(self._REPLICATED_PARAMS)
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=1,
                dp_parallel_axis=0,
            )
        )

        # Every var and op in the dist main program must carry a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context)
        )
        # Check the distributed attributes of each partitioned op; each serial
        # op index maps to the pair of dist op indices it was expanded into.
        serial_op_idx = [0, 5, 9, 11, 23, 28, 31]
        dist_op_idx = [
            [0, 1],
            [6, 7],
            [11, 12],
            [14, 15],
            [27, 28],
            [33, 34],
            [37, 38],
        ]
        self.assertTrue(
            distributed_attr_check_for_dist_op(
                serial_main_prog,
                dist_main_prog,
                dist_context,
                serial_op_idx,
                dist_op_idx,
            )
        )

    def test_decoder_noparallel(self):
        """Partition the decoder with no parallel strategy selected."""
        global _global_parallel_strategy
        # Any value outside "dp"/"mp"/"dp_mp" disables DecoderLayer's
        # sharding annotations, apart from the replicated out_proj weight.
        _global_parallel_strategy = "None"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
        )
        (
            serial_main_prog,
            serial_startup_prog,
            dist_main_prog,
            dist_startup_prog,
            dist_context,
        ) = get_programs(decoder_pretrain_forward)

        # Without a parallel strategy nothing may be split (nrank = 1).
        self._check_param_splits(dist_main_prog, serial_main_prog, nrank=1)

        # No communication ops are inserted into the main program.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'lookup_table_v2',
            'lookup_table_v2',
            'elementwise_add',
            'dropout',
            'layer_norm',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'matmul_v2',
            'elementwise_add',
            'matmul_v2',
            'elementwise_add',
            'reshape2',
            'transpose2',
            'reshape2',
            'transpose2',
            'matmul',
            'softmax',
            'dropout',
            'matmul_v2',
            'transpose2',
            'reshape2',
            'matmul_v2',
            'elementwise_add',
            'dropout',
            'elementwise_add',
            'layer_norm',
            'matmul_v2',
            'elementwise_add',
            'gelu',
            'matmul_v2',
            'elementwise_add',
            'dropout',
            'elementwise_add',
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Startup program: 16 initializer ops followed by 32 c_broadcast ops.
        dist_ops = [op.type for op in dist_startup_prog.global_block().ops]
        ref_ops = (
            ['gaussian_random'] * 3
            + ['fill_constant', 'gaussian_random'] * 5
            + ['fill_constant'] * 3
            + ['c_broadcast'] * 32
        )
        self.assertEqual(dist_ops, ref_ops)


if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main()