test_auto_parallel_partitioner.py 49.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO
import numpy as np

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
32
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
33
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
34
from paddle.distributed.auto_parallel.dist_context import DistributedContext
35 36 37
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
38
from paddle.distributed.auto_parallel.process_group import new_process_group
39 40

paddle.enable_static()

# Parallel strategy ("dp", "mp" or "dp_mp") and process mesh read by the
# layer / forward builders below. The test cases assign both (via `global`)
# before building their programs.
_global_parallel_strategy = None
_global_process_mesh = None


def get_programs(annotated_func, rank_id=3):
    """Build, complete and partition (forward only) a program for one rank.

    Args:
        annotated_func: callable ``(train_program, start_program)`` that
            populates both programs and returns them.
        rank_id: rank the Partitioner builds the distributed programs for.
            Defaults to 3; ``initialization_check`` looks up its comm groups
            for rank 3 by default too, so keep the two in sync.

    Returns:
        Tuple ``(complete_train_program, start_program, dist_main_prog,
        dist_startup_prog, dist_context)``.
    """
    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    # `_global_process_mesh` is only read here, so no `global` is required.
    dist_context.process_mesh = _global_process_mesh
    train_program, start_program = annotated_func(train_program, start_program)
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
    dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
        complete_train_program, start_program)

    return (complete_train_program, start_program, dist_main_prog,
            dist_startup_prog, dist_context)


def is_all_parameters_shape_equal(prog1, prog2):
    """Return True if both programs' parameters have identical shapes.

    Each program's parameters are ordered by name and their shapes are
    compared position by position. Only shapes are compared; the names just
    fix the ordering. A differing parameter count yields False.
    """

    def _sorted_shapes(prog):
        params = sorted(prog.all_parameters(), key=lambda param: param.name)
        return [param.shape for param in params]

    return _sorted_shapes(prog1) == _sorted_shapes(prog2)


def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
    """Check that each var of ``prog1`` is an ``nsplit``-way slice of its peer.

    For every position ``k``, the size along ``axis`` of ``varnames1[k]`` in
    ``prog1`` must equal the size along ``axis`` of ``varnames2[k]`` in
    ``prog2`` floor-divided by ``nsplit``. Stops at the first mismatch.
    """
    pairs = ((prog1.global_block().var(part_name),
              prog2.global_block().var(varnames2[pos]))
             for pos, part_name in enumerate(varnames1))
    return all(part.shape[axis] == whole.shape[axis] // nsplit
               for part, whole in pairs)


def initialization_check(mode, dist_context, dist_startup_prog,
                         serial_startup_prog, var_need_broadcast, process_mesh,
                         mp_parallel_axis, dp_parallel_axis, rank_id=3):
    """Check the ``c_broadcast`` ops in a partitioned startup program.

    Args:
        mode: strategy string; the checks run for each of the substrings
            "mp" and "dp" it contains (so "dp_mp" runs both).
        dist_context: not used; kept for signature compatibility.
        dist_startup_prog: partitioned startup program to inspect.
        serial_startup_prog: serial startup program; its parameter count is
            the expected number of broadcasts on the data-parallel ring.
        var_need_broadcast: sorted var names expected to be broadcast on the
            model-parallel ring.
        process_mesh: mesh whose ``processes``/``topology`` define the groups.
        mp_parallel_axis: mesh axis of the model-parallel group.
        dp_parallel_axis: mesh axis of the data-parallel group.
        rank_id: rank whose comm groups are looked up (default 3).

    Returns:
        True if every applicable check passes, otherwise False.
    """
    broadcast_ops = [
        op for op in dist_startup_prog.global_block().ops
        if op.type == "c_broadcast"
    ]

    def _ring_id(parallel_axis):
        # Ring id of the comm group containing `rank_id` along `parallel_axis`.
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, parallel_axis,
                                      rank_id)
        return new_process_group(group_ranks).id

    if 'mp' in mode:
        mp_ring_id = _ring_id(mp_parallel_axis)
        broadcast_varnames = sorted(
            op.desc.output_arg_names()[0] for op in broadcast_ops
            if op.desc.attr("ring_id") == mp_ring_id)
        if broadcast_varnames != var_need_broadcast:
            return False

    nbroadcast_dp = 0
    if 'dp' in mode:
        dp_ring_id = _ring_id(dp_parallel_axis)
        nparam = len(serial_startup_prog.all_parameters())
        nbroadcast_dp = sum(1 for op in broadcast_ops
                            if op.desc.attr("ring_id") == dp_ring_id)
        if nparam != nbroadcast_dp:
            return False

    if 'dp' in mode and 'mp' in mode:
        # With both strategies, no c_broadcast may fall outside the two rings.
        if len(var_need_broadcast) + nbroadcast_dp != len(broadcast_ops):
            return False

    return True


def get_input_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of ``op``'s first input var.

    The var is looked up by name in ``main_program``'s global block.
    Raises IndexError if the op has no inputs.
    """
    varname = op.desc.input_arg_names()[0]
    var = main_program.global_block().var(varname)
    return dist_context.get_tensor_dist_attr_for_program(var)


def get_output_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of ``op``'s first output var.

    The var is looked up by name in ``main_program``'s global block.
    Raises IndexError if the op has no outputs.
    """
    varname = op.desc.output_arg_names()[0]
    var = main_program.global_block().var(varname)
    return dist_context.get_tensor_dist_attr_for_program(var)


def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
    """Return True if two tensor dist_attrs share process_mesh and dims_mapping."""
    return not (serial_dist_attr.process_mesh != dist_attr.process_mesh or
                serial_dist_attr.dims_mapping != dist_attr.dims_mapping)


def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
                             dist_op_idx):
    """Check the dist ops listed by ``dist_op_idx`` against ``serial_op``.

    For every index ``i`` in ``dist_op_idx``, ``dist_ops[i]`` must:
      * have, for each input/output var, an op-level dims_mapping equal to the
        dims_mapping of that var's tensor dist_attr in ``dist_main_prog``;
      * have the same process_mesh and impl_idx as ``serial_op``;
      * be a different op from ``serial_op`` (different ``desc.id()``).

    Every listed op is checked (no early exit). Returns True only if all
    checks pass.
    """
    equal = True
    # process_mesh / impl_idx of the serial op, which every dist op must share
    serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op)
    serial_process_mesh = serial_op_dist_attr.process_mesh
    serial_impl_idx = serial_op_dist_attr.impl_idx

    for i in dist_op_idx:
        dist_op = dist_ops[i]
        op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)

        # var-level vs op-level dims_mapping for inputs
        for in_varname in dist_op.desc.input_arg_names():
            in_var = dist_main_prog.global_block().var(in_varname)
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                in_var)
            if tensor_dist_attr.dims_mapping != \
                    op_dist_attr.get_input_dims_mapping(in_varname):
                equal = False

        # var-level vs op-level dims_mapping for outputs
        for out_varname in dist_op.desc.output_arg_names():
            out_var = dist_main_prog.global_block().var(out_varname)
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                out_var)
            if tensor_dist_attr.dims_mapping != \
                    op_dist_attr.get_output_dims_mapping(out_varname):
                equal = False

        # the dist op must be a distinct op inheriting the serial op's mesh/impl
        if serial_op.desc.id() == dist_op.desc.id() or \
                serial_process_mesh != op_dist_attr.process_mesh or \
                serial_impl_idx != op_dist_attr.impl_idx:
            equal = False

    return equal


def distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                       dist_context, serial_op_idx,
                                       dist_op_idx):
    """Check dist_attrs of the dist ops produced from selected serial ops.

    Args:
        serial_main_prog: serial (completed) main program.
        dist_main_prog: partitioned main program.
        dist_context: context holding the tensor/op dist_attrs.
        serial_op_idx: indices of serial ops to check.
        dist_op_idx: for each entry of ``serial_op_idx``, the list of indices
            of the dist ops generated from it.

    For each pair, a var-level dist_attr is compared. If the first dist op is
    a ``c_identity``, its output is compared with the serial op's first input.
    Otherwise the first outputs of the two ops are compared. After that,
    ``check_equal_dist_op_attr`` validates the listed dist ops.

    Returns:
        True only if every var-level and op-level check passes for every
        serial op. Earlier versions overwrote the result on each step, so only
        the last op-level check counted.
    """
    equal = True
    serial_ops = serial_main_prog.global_block().ops
    dist_ops = dist_main_prog.global_block().ops

    for i, serial_idx in enumerate(serial_op_idx):
        serial_op = serial_ops[serial_idx]
        dist_op_0 = dist_ops[dist_op_idx[i][0]]
        if dist_op_0.type == "c_identity":
            # serial op input's dist_attr vs c_identity output's (new var)
            serial_var_dist_attr = get_input_var_dist_attr(
                serial_op, serial_main_prog, dist_context)
        else:
            # serial op output's dist_attr vs dist op output's (new var)
            serial_var_dist_attr = get_output_var_dist_attr(
                serial_op, serial_main_prog, dist_context)
        dist_var_dist_attr = get_output_var_dist_attr(dist_op_0, dist_main_prog,
                                                      dist_context)

        # Accumulate failures instead of overwriting `equal`, so a mismatch in
        # any check (var or op, for any serial op) fails the whole check.
        if not check_equal_var_dist_attr(serial_var_dist_attr,
                                         dist_var_dist_attr):
            equal = False
        if not check_equal_dist_op_attr(dist_context, dist_main_prog,
                                        serial_op, dist_ops, dist_op_idx[i]):
            equal = False

    return equal


def distributed_attr_check_for_program(dist_main_prog, dist_context):
    """Return True if every var and op in every block has a dist_attr.

    A var or op counts as missing when ``dist_context`` returns None for it.
    All blocks are scanned fully.
    """
    have_dist_attr = True
    for block in dist_main_prog.blocks:
        # `var`, not `tensor`, to avoid shadowing the `paddle.tensor` import.
        for var in block.vars.values():
            if dist_context.get_tensor_dist_attr_for_program(var) is None:
                have_dist_attr = False

        for op in block.ops:
            if dist_context.get_op_dist_attr_for_program(op) is None:
                have_dist_attr = False

    return have_dist_attr


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear -> Dropout.

    The weight sharding applied in ``forward`` is chosen from the module-level
    ``_global_parallel_strategy`` and ``_global_process_mesh``.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        # linear0's weight is split along its axis 1 and linear1's along its
        # axis 0, over mesh axis 0 for "mp" (1-D mesh) and mesh axis 1 for
        # "dp_mp" (2-D mesh). Any other strategy leaves both unsplit.
        if _global_parallel_strategy == "mp":
            linear0_dims_mapping, linear1_dims_mapping = [-1, 0], [0, -1]
        elif _global_parallel_strategy == "dp_mp":
            linear0_dims_mapping, linear1_dims_mapping = [-1, 1], [1, -1]
        else:
            linear0_dims_mapping, linear1_dims_mapping = [-1, -1], [-1, -1]

        auto.shard_tensor(
            self.linear0.weight,
            dist_attr={
                "process_mesh": _global_process_mesh,
                "dims_mapping": linear0_dims_mapping
            })
        auto.shard_tensor(
            self.linear1.weight,
            dist_attr={
                "process_mesh": _global_process_mesh,
                "dims_mapping": linear1_dims_mapping
            })

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build the MLP network into ``train_program`` / ``start_program``.

    Declares a float32 ``input`` of shape [4, 512, 1024]. Under "dp" and
    "dp_mp", its axis 0 is sharded over mesh axis 0. The input then runs
    through an ``MLPLayer``. Returns the two programs.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_var = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')

        # Both data-parallel strategies shard the batch axis identically.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input_var,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1, -1]
                })

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Called for its side effect of appending the ops to train_program.
        mlp(input_var)
    return train_program, start_program


class TestMLPAutoPartitioner(unittest.TestCase):
    """Partitioner tests for the MLP block under dp, mp and dp_mp."""

    def _check_mlp_mp_partition(self, serial_main_prog, dist_main_prog):
        """Check weight splits and the op sequence shared by mp and dp_mp."""
        nrank = 4
        # (parameter name, split axis, expected number of splits)
        expected_splits = [
            ('linear_0.w_0', 1, nrank),  # col parallel
            ('linear_0.b_0', 0, nrank),
            ('linear_1.w_0', 0, nrank),  # row parallel
            ('linear_1.b_0', 0, 1),  # row-parallel bias stays whole
        ]
        for name, axis, nsplit in expected_splits:
            self.assertTrue(
                check_tensor_split(dist_main_prog, [name], serial_main_prog,
                                   [name], axis, nsplit),
                msg=name)

        # row and col allreduce
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout'
        ]
        self.assertEqual(dist_ops, ref_ops)

    def _check_mlp_mp_dist_attr(self, serial_main_prog, dist_main_prog,
                                dist_context):
        """Check dist_attrs of an mp-partitioned MLP main program."""
        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # serial ops 1 and 4 map to dist ops [c_identity, matmul_v2] and
        # [matmul_v2, c_allreduce_sum] respectively (see ref_ops)
        serial_op_idx = [1, 4]
        dist_op_idx = [[1, 2], [5, 6]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_mlp_dp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            mlp_pretrain_forward)

        # parameter should not be partitioned
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # op in main prog should be the same
        serial_ops = [op.type for op in serial_main_prog.global_block().ops]
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        self.assertEqual(serial_ops, dist_ops)

        # parameter initialization
        var_need_broadcast = []
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=None,
                dp_parallel_axis=0))

    def test_mlp_mp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            mlp_pretrain_forward)

        self._check_mlp_mp_partition(serial_main_prog, dist_main_prog)

        # parameter initialization
        var_need_broadcast = sorted(
            ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=0,
                dp_parallel_axis=None))

        self._check_mlp_mp_dist_attr(serial_main_prog, dist_main_prog,
                                     dist_context)

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp_mp"
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            mlp_pretrain_forward)

        self._check_mlp_mp_partition(serial_main_prog, dist_main_prog)

        # parameter initialization
        var_need_broadcast = sorted(
            ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=1,
                dp_parallel_axis=0))

        self._check_mlp_mp_dist_attr(serial_main_prog, dist_main_prog,
                                     dist_context)

class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used by the partitioner tests.

    Input and weight sharding in ``forward`` is chosen from the module-level
    ``_global_parallel_strategy`` and ``_global_process_mesh``.
    """

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

    def forward(self, input):
        # Both data-parallel strategies split input axis 0 over mesh axis 0.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1, -1]
                })

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        # Mesh axis used for model parallelism: 0 for "mp" (1-D mesh), 1 for
        # "dp_mp" (2-D mesh); None means the projection weights are not
        # annotated. q/k/v weights are split along axis 1, out_proj along 0.
        mp_mesh_axis = {"mp": 0, "dp_mp": 1}.get(_global_parallel_strategy)
        if mp_mesh_axis is not None:
            for proj in (self.q_proj, self.k_proj, self.v_proj):
                auto.shard_tensor(
                    proj.weight,
                    dist_attr={
                        "process_mesh": _global_process_mesh,
                        "dims_mapping": [-1, mp_mesh_axis]
                    })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if mp_mesh_axis is not None:
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [mp_mesh_axis, -1]
                })

        return out


def attn_pretrain_forward(train_program, start_program):
    """Append the attention network to ``train_program`` / ``start_program``.

    Declares a float32 ``query`` input of shape [4, 512, 1024] and passes it
    through a 16-head ``AttentionLayer``. Returns the two programs unchanged
    as objects (they are populated in place).
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')
        attention = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        attention(query)

    return train_program, start_program


class TestAttentionAutoPartitioner(unittest.TestCase):
    """Partitioner checks for the standalone attention layer."""

    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            attn_pretrain_forward)
        # Pure data parallelism must not partition any parameter.
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # The op sequence of the main program must be unchanged.
        serial_ops = [op.type for op in serial_main_prog.global_block().ops]
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        self.assertEqual(serial_ops, dist_ops)

        # Parameter initialization: no variable needs an extra broadcast.
        var_need_broadcast = []
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=None,
                dp_parallel_axis=0))

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        self._check_attn_mp_partition(mp_parallel_axis=0, dp_parallel_axis=None)

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        self._check_attn_mp_partition(mp_parallel_axis=1, dp_parallel_axis=0)

    def _check_attn_mp_partition(self, mp_parallel_axis, dp_parallel_axis):
        """Run the checks shared by every strategy with a 4-way mp axis.

        Reads the module-level ``_global_parallel_strategy`` and
        ``_global_process_mesh``, which the calling test must set first.

        Args:
            mp_parallel_axis: mesh axis used for model parallelism.
            dp_parallel_axis: mesh axis used for data parallelism, or None.
        """
        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            attn_pretrain_forward)

        # Parameters should be partitioned across the model-parallel ranks.
        nrank = 4
        # Column parallel: q/k/v projection weights and biases.
        weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 1, nrank))
        weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        # Row parallel: output projection weight; its bias is not split.
        weights = ['linear_3.w_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        weights = ['linear_3.b_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, 1))

        # Column-parallel inputs get c_identity; the row-parallel output
        # gets c_allreduce_sum.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add',
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax',
            'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Parameter initialization: the unsplit bias must be broadcast.
        var_need_broadcast = ['linear_3.b_0']
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=mp_parallel_axis,
                dp_parallel_axis=dp_parallel_axis))

        # Every var and op in the distributed main program has a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # Check the distributed attrs of the ops each serial op maps to.
        serial_op_idx = [0, 4, 6, 18]
        dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))


class DecoderLayer(nn.Layer):
    """Single pre-norm decoder layer used to exercise the partitioner.

    Embeds ``input_ids`` / ``position_ids``, then applies a multi-head
    self-attention sub-layer and a GELU feed-forward sub-layer, each followed
    by a residual add.

    Parameters are annotated with ``auto.shard_tensor`` according to the
    module-level ``_global_parallel_strategy`` and ``_global_process_mesh``:
    "mp" uses mesh axis 0 for model parallelism, "dp_mp" uses mesh axis 1.
    "dp" only shards ``input_ids``.

    The same ``self.norm`` LayerNorm is applied before both sub-layers, so
    the programs contain a single ``layer_norm_0`` parameter set.
    """

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))

        # Attention projections (created in this order: linear_0 .. linear_3).
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

        # Feed-forward sub-layer (linear_4, linear_5). The constructor's
        # `intermediate_size` argument was previously ignored in favour of a
        # hard-coded 4 * hidden_size; it is now honoured.
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Run embedding, attention and MLP; return the layer output.

        Sharding annotations are issued at the same points as the ops they
        describe; keep their relative order when editing.
        """
        # Shard the token ids along mesh axis 0 for data-parallel strategies.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input_ids,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Split the word embedding table along its first (vocabulary) dim.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        # Column parallel: split q/k/v weights along their second dim.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        # Row parallel: split the output projection along its first dim;
        # any other strategy marks it explicitly as replicated.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })
        else:
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, -1]
                })

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        # MLP: first linear is column parallel, second is row parallel.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build one ``DecoderLayer`` forward pass into the given programs.

    Args:
        train_program: static main program that receives the forward ops.
        start_program: static startup program that receives the parameter
            initializer ops.

    Returns:
        The ``(train_program, start_program)`` pair after construction.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
        position_ids = static.data(
            name="position_ids",
            shape=[batch_size, sequence_len],
            dtype='int64')
        decoder = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Only the ops/parameters recorded into the programs are needed; the
        # returned tensor itself is unused.
        decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerPartitioner(unittest.TestCase):
    """Partitioner checks for a full decoder layer."""

    def test_decoder_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            decoder_pretrain_forward)

        # Parameters should be partitioned across the 4 model-parallel ranks.
        nrank = 4
        # Column parallel: q/k/v projections and the first MLP linear.
        weights = [
            'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 1, nrank))
        weights = [
            'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        # Row parallel: word embeddings, output projection, second MLP linear.
        weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        # These parameters stay unsplit.
        weights = [
            'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
            'layer_norm_0.w_0', 'linear_5.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, 1))

        # Column-parallel inputs get c_identity; row-parallel outputs (and the
        # sharded embedding lookup) get c_allreduce_sum.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_embedding', 'c_allreduce_sum', 'lookup_table_v2',
            'elementwise_add', 'dropout', 'layer_norm', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'c_identity', 'matmul_v2', 'elementwise_add', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout',
            'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add',
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout',
            'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Parameter initialization: unsplit parameters must be broadcast.
        var_need_broadcast = sorted([
            'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
            'layer_norm_0.w_0', 'linear_5.b_0'
        ])
        self.assertTrue(
            initialization_check(
                _global_parallel_strategy,
                dist_context,
                dist_startup_prog,
                serial_startup_prog,
                var_need_broadcast,
                _global_process_mesh,
                mp_parallel_axis=1,
                dp_parallel_axis=0))

        # Every var and op in the distributed main program has a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # Check the distributed attrs of the ops each serial op maps to.
        serial_op_idx = [0, 5, 9, 11, 23, 28, 31]
        dist_op_idx = [[0, 1], [6, 7], [11, 12], [14, 15], [27, 28], [33, 34],
                       [37, 38]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_decoder_noparallel(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "None"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            decoder_pretrain_forward)

        # No parallel strategy: with nrank = 1 every parameter is expected to
        # keep its serial shape.
        nrank = 1
        weights = [
            'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 1, nrank))
        weights = [
            'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        weights = [
            'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
            'layer_norm_0.w_0', 'linear_5.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, 1))

        # The main program must contain no communication ops.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout',
            'layer_norm', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'matmul_v2', 'elementwise_add', 'matmul_v2',
            'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
            'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
            'transpose2', 'reshape2', 'matmul_v2', 'elementwise_add', 'dropout',
            'elementwise_add', 'layer_norm', 'matmul_v2', 'elementwise_add',
            'gelu', 'matmul_v2', 'elementwise_add', 'dropout', 'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Startup program: 16 initializer ops followed by 32 c_broadcast ops.
        startup_ops = [op.type for op in dist_startup_prog.global_block().ops]
        ref_startup_ops = [
            'gaussian_random', 'gaussian_random', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'fill_constant',
            'fill_constant'
        ] + ['c_broadcast'] * 32
        self.assertEqual(startup_ops, ref_startup_ops)


if __name__ == "__main__":
    # Run every unittest.TestCase in this module when executed as a script.
    unittest.main()