# test_auto_parallel_partitioner.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO
import numpy as np

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process_group import new_process_group

# The tests in this file build static-graph programs.
paddle.enable_static()

# Shared test configuration. Each test case assigns these before building a
# network; the layer ``forward`` methods and ``get_programs`` read them.
# The tests set the strategy to one of "dp", "mp" or "dp_mp".
_global_parallel_strategy = None
# The tests set this to the ``auto.ProcessMesh`` used for shard annotations.
_global_process_mesh = None


def get_programs(annotated_func):
    """Build an annotated serial program and partition it for rank 3.

    Args:
        annotated_func: callable taking ``(train_program, start_program)`` and
            returning the same pair after the network has been built in them.

    Returns:
        tuple: ``(complete_train_program, start_program, dist_main_prog,
        dist_startup_prog, dist_context)`` where the ``dist_*`` programs are
        the first two results of ``Partitioner.partition``.
    """
    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    # Only read here, so no ``global`` declaration is needed.
    dist_context.process_mesh = _global_process_mesh
    train_program, start_program = annotated_func(train_program, start_program)

    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    rank_id = 3
    # NOTE(review): the strategy object is never used afterwards; it is kept
    # in case constructing it has side effects -- confirm and remove.
    dist_strategy = fleet.DistributedStrategy()

    partitioner = Partitioner(dist_context, rank_id)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        complete_train_program, start_program, [])

    return (complete_train_program, start_program, dist_main_prog,
            dist_startup_prog, dist_context)


def is_all_parameters_shape_equal(prog1, prog2):
    """Tell whether two programs hold parameters with pairwise equal shapes.

    Parameters of each program are ordered by name before their shapes are
    compared position by position.

    Returns:
        bool: ``False`` if the parameter counts differ or any pair of shapes
        differs, ``True`` otherwise.
    """
    by_name = lambda param: param.name

    first_params = prog1.all_parameters()
    second_params = prog2.all_parameters()
    first_params.sort(key=by_name)
    second_params.sort(key=by_name)

    first_shapes = [param.shape for param in first_params]
    second_shapes = [param.shape for param in second_params]

    if len(first_shapes) != len(second_shapes):
        return False
    return not any(
        lhs != rhs for lhs, rhs in zip(first_shapes, second_shapes))


def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
    """Check that every ``prog1`` variable is a ``1/nsplit`` slice along ``axis``.

    For each position, the size along ``axis`` of the variable named in
    ``varnames1`` (looked up in ``prog1``) must equal the size along ``axis``
    of the variable named in ``varnames2`` (looked up in ``prog2``) integer
    divided by ``nsplit``.

    Returns:
        bool: ``True`` when every pair satisfies the relation.
    """
    return all(
        prog1.global_block().var(split_name).shape[axis] ==
        prog2.global_block().var(varnames2[pos]).shape[axis] // nsplit
        for pos, split_name in enumerate(varnames1))


def initialization_check(mode, dist_context, dist_startup_prog,
                         serial_startup_prog, var_need_broadcast, process_mesh,
                         mp_parallel_axis, dp_parallel_axis):
    """Check the ``c_broadcast`` ops in the partitioned startup program.

    Args:
        mode: strategy string; the substrings ``'mp'`` and ``'dp'`` select
            which checks run.
        dist_context: unused by this function.
        dist_startup_prog: partitioned startup program to inspect.
        serial_startup_prog: serial startup program; its parameter count is
            the expected number of dp broadcasts.
        var_need_broadcast: sorted list of variable names expected to be
            broadcast on the mp ring.
        process_mesh: mesh whose ``processes`` and ``topology`` are used to
            derive the communication groups.
        mp_parallel_axis: mesh axis used for the mp group.
        dp_parallel_axis: mesh axis used for the dp group.

    Returns:
        bool: ``True`` when every selected check passes.
    """
    # Rank used to look up the communication groups (same literal value as
    # the rank the partitioner is built with in ``get_programs``).
    rank = 3

    def broadcast_ops(ring_id=None):
        """Return the c_broadcast ops, optionally limited to one ring."""
        return [
            op for op in dist_startup_prog.global_block().ops
            if op.type == "c_broadcast" and
            (ring_id is None or op.desc.attr("ring_id") == ring_id)
        ]

    if 'mp' in mode:
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, mp_parallel_axis,
                                      rank)
        mp_ring_id = new_process_group(group_ranks).id
        broadcast_varnames = sorted(
            op.desc.output_arg_names()[0] for op in broadcast_ops(mp_ring_id))
        if broadcast_varnames != var_need_broadcast:
            return False

    if 'dp' in mode:
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, dp_parallel_axis,
                                      rank)
        dp_ring_id = new_process_group(group_ranks).id
        nparam = len(serial_startup_prog.all_parameters())
        nbroadcast_dp = len(broadcast_ops(dp_ring_id))
        if nparam != nbroadcast_dp:
            return False

    if "dp" in mode and 'mp' in mode:
        # Every broadcast must belong to either the mp or the dp ring.
        nbroadcast = len(broadcast_ops())
        if len(var_need_broadcast) + nbroadcast_dp != nbroadcast:
            return False

    return True


def get_input_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of the first input argument of ``op``.

    The variable is looked up by name in ``main_program``'s global block and
    its distributed attribute is fetched from ``dist_context``.
    """
    first_input_name = op.desc.input_arg_names()[0]
    input_var = main_program.global_block().var(first_input_name)
    return dist_context.get_tensor_dist_attr_for_program(input_var)


def get_output_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of the first output argument of ``op``.

    The variable is looked up by name in ``main_program``'s global block and
    its distributed attribute is fetched from ``dist_context``.
    """
    first_output_name = op.desc.output_arg_names()[0]
    output_var = main_program.global_block().var(first_output_name)
    return dist_context.get_tensor_dist_attr_for_program(output_var)


def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
    """Tell whether two tensor dist_attrs agree.

    Returns:
        bool: ``True`` when both ``process_mesh`` and ``dims_mapping`` compare
        equal, ``False`` otherwise.
    """
    differs = (serial_dist_attr.process_mesh != dist_attr.process_mesh or
               serial_dist_attr.dims_mapping != dist_attr.dims_mapping)
    return not differs


def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
                             dist_op_idx):
    """Check the dist_attr of the distributed ops derived from ``serial_op``.

    For every index in ``dist_op_idx`` the op ``dist_ops[index]`` must satisfy:

    * each input/output variable's tensor ``dims_mapping`` equals the
      dims_mapping the op's dist_attr records for that argument;
    * its ``desc.id()`` differs from the serial op's;
    * its ``process_mesh`` and ``impl_idx`` equal the serial op's.

    Returns:
        bool: ``True`` when all selected ops satisfy every condition.
    """
    # Reference values taken from the serial op.
    serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op)
    serial_process_mesh = serial_op_dist_attr.process_mesh
    serial_impl_idx = serial_op_dist_attr.impl_idx

    def tensor_dims_mapping(varname):
        """Return the dims_mapping stored for the named variable."""
        var = dist_main_prog.global_block().var(varname)
        return dist_context.get_tensor_dist_attr_for_program(var).dims_mapping

    equal = True
    for idx in dist_op_idx:
        dist_op = dist_ops[idx]
        op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)

        for in_varname in dist_op.desc.input_arg_names():
            if tensor_dims_mapping(in_varname) != \
                    op_dist_attr.get_input_dims_mapping(in_varname):
                equal = False

        for out_varname in dist_op.desc.output_arg_names():
            if tensor_dims_mapping(out_varname) != \
                    op_dist_attr.get_output_dims_mapping(out_varname):
                equal = False

        # A matching desc id is treated as a failure: the check only passes
        # when the distributed op's id differs from the serial op's.
        if serial_op.desc.id() == dist_op.desc.id() or \
                serial_process_mesh != op_dist_attr.process_mesh or \
                serial_impl_idx != op_dist_attr.impl_idx:
            equal = False

    return equal


def distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                       dist_context, serial_op_idx,
                                       dist_op_idx):
    """Check dist_attrs of serial ops against the ops they were split into.

    Args:
        serial_main_prog: serial main program.
        dist_main_prog: partitioned main program.
        dist_context: context holding the distributed attributes.
        serial_op_idx: indices of the serial ops to check.
        dist_op_idx: for each serial op, the list of indices of its
            distributed ops in ``dist_main_prog``.

    Returns:
        bool: ``True`` only if, for every listed serial op, both the variable
        dist_attr check and the op dist_attr check pass.
    """
    serial_ops = serial_main_prog.global_block().ops
    dist_ops = dist_main_prog.global_block().ops

    equal = True
    for pos, serial_idx in enumerate(serial_op_idx):
        serial_op = serial_ops[serial_idx]
        dist_group = dist_op_idx[pos]
        dist_op_0 = dist_ops[dist_group[0]]

        if dist_op_0.type == "c_identity":
            # Compare the serial op's input with c_identity's new output var.
            serial_var_dist_attr = get_input_var_dist_attr(
                serial_op, serial_main_prog, dist_context)
        else:
            # Compare the serial op's output with the dist op's new output var.
            serial_var_dist_attr = get_output_var_dist_attr(
                serial_op, serial_main_prog, dist_context)
        dist_var_dist_attr = get_output_var_dist_attr(dist_op_0, dist_main_prog,
                                                      dist_context)
        var_equal = check_equal_var_dist_attr(serial_var_dist_attr,
                                              dist_var_dist_attr)

        # check op's dist_attr
        op_equal = check_equal_dist_op_attr(dist_context, dist_main_prog,
                                            serial_op, dist_ops, dist_group)

        # Accumulate: previously each assignment overwrote the result, so only
        # the last op check of the last iteration was ever reported.
        equal = equal and var_equal and op_equal

    return equal


def distributed_attr_check_for_program(dist_main_prog, dist_context):
    """Tell whether every variable and op of the program has a dist_attr.

    Returns:
        bool: ``False`` as soon as ``dist_context`` returns ``None`` for any
        variable or op in any block of ``dist_main_prog``; ``True`` otherwise.
    """
    for block in dist_main_prog.blocks:
        # ``var`` rather than ``tensor`` so the imported paddle.tensor module
        # name is not shadowed inside this function.
        for var in block.vars.values():
            if dist_context.get_tensor_dist_attr_for_program(var) is None:
                return False

        for op in block.ops:
            if dist_context.get_op_dist_attr_for_program(op) is None:
                return False

    return True


class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GELU -> Linear -> Dropout block used by the tests.

    The shard annotations applied in ``forward`` depend on the module-level
    ``_global_parallel_strategy`` and ``_global_process_mesh``.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        """Create the sub-layers.

        Args:
            hidden_size: input/output feature size.
            intermediate_size: feature size between the two linear layers.
            dropout_ratio: probability passed to ``nn.Dropout``.
            initializer_range: std of the normal initializer for the weights.
        """
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Annotate the linear weights, then apply the block to ``input``."""
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # linear0 is sharded along its second dim, linear1 along its first.
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])
        else:
            # No weight dimension is sharded.
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, None])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build the MLP forward network inside the given programs.

    Args:
        train_program: main program the network is built in.
        start_program: startup program paired with ``train_program``.

    Returns:
        tuple: the same ``(train_program, start_program)`` pair.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        # Local name chosen so the ``input`` builtin is not shadowed; the
        # variable's name inside the program is still "input".
        data = static.data(name="input",
                           shape=[batch_size, sequence_len, hidden_size],
                           dtype='float32')

        if _global_parallel_strategy in ["dp", "dp_mp"]:
            # Shard the first (batch) dimension over the "dp" mesh axis.
            auto.shard_tensor(data,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       dropout_ratio=0.1,
                       initializer_range=0.02)
        # Called only to append the ops to the program; the result is unused.
        mlp(data)
    return train_program, start_program


class TestMLPAutoPartitioner(unittest.TestCase):
    """Partitioner tests for the MLP network under dp, mp and dp_mp."""

    def _check_mp_partition(self, mp_parallel_axis, dp_parallel_axis):
        """Run the checks shared by the "mp" and "dp_mp" test cases.

        The caller must have set ``_global_parallel_strategy`` and
        ``_global_process_mesh`` before calling.

        Args:
            mp_parallel_axis: mesh axis forwarded to ``initialization_check``.
            dp_parallel_axis: mesh axis forwarded to ``initialization_check``.
        """
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog, dist_context) = get_programs(mlp_pretrain_forward)

        # param should be partition
        nrank = 4
        # (variable names, split axis, number of splits)
        split_specs = [
            (['linear_0.w_0'], 1, nrank),  # col parallel weight
            (['linear_0.b_0'], 0, nrank),  # col parallel bias
            (['linear_1.w_0'], 0, nrank),  # row parallel weight
            (['linear_1.b_0'], 0, 1),  # row parallel bias: checked with 1
        ]
        for weights, axis, nsplit in split_specs:
            self.assertTrue(
                check_tensor_split(dist_main_prog, weights, serial_main_prog,
                                   weights, axis, nsplit))

        # row and col allreduce
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # parameter initialization
        var_need_broadcast = sorted(
            ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=mp_parallel_axis,
                                 dp_parallel_axis=dp_parallel_axis))

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # check distributed attr for dist op
        serial_op_idx = [1, 4]
        dist_op_idx = [[1, 2], [5, 6]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])

        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog, dist_context) = get_programs(mlp_pretrain_forward)

        # parameter should not be partitioned
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # op in main prog should be the same
        serial_ops = [op.type for op in serial_main_prog.global_block().ops]
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        self.assertEqual(serial_ops, dist_ops)

        # parameter initialization
        var_need_broadcast = []
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=None,
                                 dp_parallel_axis=0))

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        self._check_mp_partition(mp_parallel_axis=0, dp_parallel_axis=None)

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        self._check_mp_partition(mp_parallel_axis=1, dp_parallel_axis=0)


class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used by the partitioner tests.

    The shard annotations applied in ``forward`` depend on the module-level
    ``_global_parallel_strategy`` and ``_global_process_mesh``.
    """

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        """Create the q/k/v/out projections.

        Args:
            hidden_size: embedding size; must be divisible by ``num_heads``.
            sequence_len: stored on the instance; not otherwise used here.
            intermediate_size: accepted but not used by this layer.
            num_heads: number of attention heads.
            dropout_ratio: dropout probability for the attention weights;
                a falsy value disables the dropout op.
            initializer_range: std of the normal initializer for the weights.
        """
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

    def forward(self, input):
        """Annotate tensors for the active strategy and apply self-attention."""
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            # Shard the first dimension of the input over the "dp" mesh axis.
            auto.shard_tensor(input,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # q/k/v projection weights are sharded along their second dim.
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # The output projection weight is sharded along its first dim.
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build the attention forward network inside the given programs.

    Args:
        train_program: main program the network is built in.
        start_program: startup program paired with ``train_program``.

    Returns:
        tuple: the same ``(train_program, start_program)`` pair.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        # Local name chosen so the ``input`` builtin is not shadowed; the
        # variable's name inside the program is still "query".
        query = static.data(name="query",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')
        attn = AttentionLayer(hidden_size=hidden_size,
                              sequence_len=sequence_len,
                              intermediate_size=4 * hidden_size,
                              num_heads=16,
                              dropout_ratio=0.1,
                              initializer_range=0.02)
        # Called only to append the ops to the program; the result is unused.
        attn(query)

    return train_program, start_program


class TestAttentionAutoPartitioner(unittest.TestCase):
    """Partitioner tests for the attention network under dp, mp and dp_mp."""

    def _check_mp_partition(self, mp_parallel_axis, dp_parallel_axis):
        """Run the checks shared by the "mp" and "dp_mp" test cases.

        The caller must have set ``_global_parallel_strategy`` and
        ``_global_process_mesh`` before calling.

        Args:
            mp_parallel_axis: mesh axis forwarded to ``initialization_check``.
            dp_parallel_axis: mesh axis forwarded to ``initialization_check``.
        """
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog, dist_context) = get_programs(attn_pretrain_forward)

        # param should be partition
        nrank = 4
        # (variable names, split axis, number of splits)
        split_specs = [
            # col parallel weights and biases
            (['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'], 1, nrank),
            (['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'], 0, nrank),
            # row parallel weight, and its bias checked with 1
            (['linear_3.w_0'], 0, nrank),
            (['linear_3.b_0'], 0, 1),
        ]
        for weights, axis, nsplit in split_specs:
            self.assertTrue(
                check_tensor_split(dist_main_prog, weights, serial_main_prog,
                                   weights, axis, nsplit))

        # row and col allreduce
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add',
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax',
            'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # parameter initialization
        var_need_broadcast = ['linear_3.b_0']
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=mp_parallel_axis,
                                 dp_parallel_axis=dp_parallel_axis))

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # check distributed attr for dist op
        serial_op_idx = [0, 4, 6, 18]
        dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])

        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog, dist_context) = get_programs(attn_pretrain_forward)

        # parameter should not be partitioned
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # op in main prog should be the same
        serial_ops = [op.type for op in serial_main_prog.global_block().ops]
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        self.assertEqual(serial_ops, dist_ops)

        # parameter initialization
        var_need_broadcast = []
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=None,
                                 dp_parallel_axis=0))

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        self._check_mp_partition(mp_parallel_axis=0, dp_parallel_axis=None)

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        self._check_mp_partition(mp_parallel_axis=1, dp_parallel_axis=0)


class DecoderLayer(nn.Layer):
    """Pre-norm transformer decoder block used as a partitioner fixture.

    The forward pass sums word and position embeddings, then applies a
    self-attention sub-block and a two-layer GELU MLP, each followed by a
    residual add. The same ``LayerNorm`` instance (``self.norm``) is applied
    before both sub-blocks.

    Sharding annotations are driven by the module-level globals
    ``_global_parallel_strategy`` and ``_global_process_mesh``: inputs are
    annotated for "dp"/"dp_mp" and weights for "mp"/"dp_mp".

    Note: the ``intermediate_size`` argument is accepted but not used; the
    MLP width is always ``4 * hidden_size``.
    """

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super().__init__()

        # Plain configuration values.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = hidden_size
        self.kdim = hidden_size
        self.vdim = hidden_size
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"

        def make_normal_attr(name=None):
            # Parameter attribute with a Normal(0, initializer_range) init.
            return paddle.ParamAttr(
                name=name,
                initializer=nn.initializer.Normal(mean=0.0,
                                                  std=initializer_range))

        # NOTE: sublayers are created in a fixed order; the tests refer to
        # the auto-generated parameter names (linear_0 ... linear_5).
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=make_normal_attr("word_embeddings"))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=make_normal_attr("pos_embeddings"))

        # Attention projections share a single weight attribute object.
        attn_weight_attr = make_normal_attr()
        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr=attn_weight_attr,
                                bias_attr=None)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr=attn_weight_attr,
                                bias_attr=None)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr=attn_weight_attr,
                                bias_attr=None)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr=attn_weight_attr,
                                  bias_attr=None)

        # MLP part. The width ignores the ``intermediate_size`` argument and
        # is derived from ``hidden_size`` instead.
        model_dim = self.hidden_size
        ffn_dim = 4 * self.hidden_size
        mlp_weight_attr = make_normal_attr()
        self.linear0 = nn.Linear(model_dim,
                                 ffn_dim,
                                 weight_attr=mlp_weight_attr,
                                 bias_attr=None)
        self.linear1 = nn.Linear(ffn_dim,
                                 model_dim,
                                 weight_attr=mlp_weight_attr,
                                 bias_attr=None)

        self.norm = nn.LayerNorm(model_dim, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def _split_heads(self, x):
        # Reshape the last dim into (num_heads, head_dim), then swap axes
        # 1 and 2 so the head axis comes before the sequence axis.
        x = tensor.reshape(x=x, shape=[0, 0, self.num_heads, self.head_dim])
        return tensor.transpose(x=x, perm=[0, 2, 1, 3])

    def forward(self, input_ids, position_ids):
        """Run the decoder block and return the final residual output.

        Args:
            input_ids: token ids fed to ``word_embeddings``.
            position_ids: position ids fed to ``position_embeddings``.
        """
        strategy = _global_parallel_strategy
        mesh = _global_process_mesh
        data_parallel = strategy in ["dp", "dp_mp"]
        model_parallel = strategy in ["mp", "dp_mp"]

        if data_parallel:
            auto.shard_tensor(input_ids,
                              process_mesh=mesh,
                              shard_spec=["dp", None])

        word_emb = self.word_embeddings(input_ids)
        pos_emb = self.position_embeddings(position_ids)

        if model_parallel:
            auto.shard_tensor(self.word_embeddings.weight,
                              process_mesh=mesh,
                              shard_spec=["mp", None])

        embeddings = self.dropout1(word_emb + pos_emb)

        # Pre-norm
        normed = self.norm(embeddings)

        # ---- attention ----
        q = self._split_heads(self.q_proj(normed))
        k = self.k_proj(normed)
        v = self.v_proj(normed)

        if model_parallel:
            for proj in (self.q_proj, self.k_proj, self.v_proj):
                auto.shard_tensor(proj.weight,
                                  process_mesh=mesh,
                                  shard_spec=[None, "mp"])

        k = self._split_heads(k)
        v = self._split_heads(v)

        # scale dot product attention
        scores = layers.matmul(x=q,
                               y=k,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)
        if self.attn_mask is not None:
            scores = scores + self.attn_mask

        attn = F.softmax(scores)
        if self.dropout_ratio:
            attn = F.dropout(attn,
                             self.dropout_ratio,
                             training=self.training,
                             mode="upscale_in_train")

        context = tensor.matmul(attn, v)

        # combine heads
        context = tensor.transpose(context, perm=[0, 2, 1, 3])
        context = tensor.reshape(
            x=context, shape=[0, 0, context.shape[2] * context.shape[3]])

        # project to output
        attn_out = self.out_proj(context)

        # The output projection is annotated in every strategy: split on its
        # first axis for model parallelism, otherwise left unsplit.
        out_proj_spec = ["mp", None] if model_parallel else [None, None]
        auto.shard_tensor(self.out_proj.weight,
                          process_mesh=mesh,
                          shard_spec=out_proj_spec)

        # Add residual
        residual = embeddings + self.dropout2(attn_out)

        # ---- MLP (pre-norm, reusing the same LayerNorm) ----
        mlp_in = self.norm(residual)
        hidden = self.linear0(mlp_in)
        activated = F.gelu(hidden, approximate=True)
        mlp_out = self.linear1(activated)

        if model_parallel:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=mesh,
                              shard_spec=["mp", None])

        # Add residual
        return residual + self.dropout3(mlp_out)


def decoder_pretrain_forward(train_program, start_program):
    """Build the DecoderLayer forward graph inside the given programs.

    Two int64 inputs, ``input_ids`` and ``position_ids``, each of shape
    ``[4, 512]``, are declared and fed through a ``DecoderLayer`` with
    ``hidden_size=1024``. The layer output is not used further.

    Args:
        train_program: program that receives the forward ops.
        start_program: program that receives parameter initialization.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024

    name_scope = utils.unique_name.guard()
    with static.program_guard(train_program, start_program), name_scope:
        token_ids = static.data(name="input_ids",
                                shape=[batch_size, sequence_len],
                                dtype='int64')
        pos_ids = static.data(name="position_ids",
                              shape=[batch_size, sequence_len],
                              dtype='int64')
        layer = DecoderLayer(vocab_size=32768,
                             hidden_size=hidden_size,
                             sequence_len=sequence_len,
                             max_position_embeddings=512,
                             intermediate_size=4 * hidden_size,
                             num_heads=16,
                             dropout_ratio=0.1,
                             initializer_range=0.02)
        # Called only for its effect of adding ops to the programs.
        layer(token_ids, pos_ids)

    return train_program, start_program


class TestDecoderLayerPartitioner(unittest.TestCase):
    """Partitioner checks for the ``DecoderLayer`` fixture.

    Each test sets the module-level ``_global_parallel_strategy`` and
    ``_global_process_mesh`` (read by ``DecoderLayer.forward``), builds the
    serial and partitioned programs through ``get_programs``, and compares
    parameter splits and op-type sequences against hard-coded references.
    """

    # Parameter names grouped by the kind of split the tests check for.
    _COL_PARALLEL_WEIGHTS = ('linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0',
                             'linear_4.w_0')
    _COL_PARALLEL_BIASES = ('linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0',
                            'linear_4.b_0')
    _ROW_PARALLEL_WEIGHTS = ('word_embeddings', 'linear_3.w_0', 'linear_5.w_0')
    _UNSPLIT_PARAMS = ('linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
                       'layer_norm_0.w_0', 'linear_5.b_0')

    @staticmethod
    def _op_types(program):
        # Op type names of the program's global block, in program order.
        return [op.type for op in program.global_block().ops]

    def _assert_param_splits(self, dist_main_prog, serial_main_prog, nrank):
        # Run check_tensor_split for every parameter group. The last two
        # positional arguments are forwarded exactly as the original test
        # passed them (presumably split axis and number of parts -- confirm
        # against check_tensor_split).
        expectations = (
            # col parallel
            ("col-parallel weights", self._COL_PARALLEL_WEIGHTS, 1, nrank),
            ("col-parallel biases", self._COL_PARALLEL_BIASES, 0, nrank),
            # row parallel
            ("row-parallel weights", self._ROW_PARALLEL_WEIGHTS, 0, nrank),
            ("unsplit parameters", self._UNSPLIT_PARAMS, 0, 1),
        )
        for label, names, axis, parts in expectations:
            # Fresh list per check so the shared class constants are never
            # handed to the helper directly.
            names = list(names)
            self.assertTrue(
                check_tensor_split(dist_main_prog, names, serial_main_prog,
                                   names, axis, parts),
                msg="check_tensor_split failed for %s %s (args %d, %d)" %
                (label, names, axis, parts))

    def test_decoder_dp_mp(self):
        """Combined data/model parallel on a 2x4 mesh named ("dp", "mp")."""
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = "dp_mp"
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog,
         dist_context) = get_programs(decoder_pretrain_forward)

        # param should be partition
        self._assert_param_splits(dist_main_prog, serial_main_prog, nrank=4)

        # row and col allreduce
        ref_ops = [
            'c_embedding', 'c_allreduce_sum', 'lookup_table_v2',
            'elementwise_add', 'dropout', 'layer_norm', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'c_identity', 'matmul_v2', 'elementwise_add', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout',
            'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add',
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout',
            'elementwise_add'
        ]
        # assertEqual reports the differing elements on failure, unlike
        # assertTrue(a == b).
        self.assertEqual(self._op_types(dist_main_prog), ref_ops)

        # parameter initialization
        var_need_broadcast = sorted(self._UNSPLIT_PARAMS)
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=1,
                                 dp_parallel_axis=0))

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # check distributed attr
        serial_op_idx = [0, 5, 9, 11, 23, 28, 31]
        dist_op_idx = [[0, 1], [6, 7], [11, 12], [14, 15], [27, 28], [33, 34],
                       [37, 38]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_decoder_noparallel(self):
        """No dp/mp strategy on a 2x4 mesh named ("x", "y")."""
        global _global_parallel_strategy
        global _global_process_mesh
        # The string "None" (not the None object) matches neither the
        # "dp"/"dp_mp" nor the "mp"/"dp_mp" branch in DecoderLayer.forward.
        _global_parallel_strategy = "None"
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["x", "y"])
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog,
         dist_context) = get_programs(decoder_pretrain_forward)

        # With nrank == 1 every parameter group is checked with 1 part.
        self._assert_param_splits(dist_main_prog, serial_main_prog, nrank=1)

        # Main program: no communication ops expected.
        ref_ops = [
            'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout',
            'layer_norm', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'matmul_v2', 'elementwise_add', 'matmul_v2',
            'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
            'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
            'transpose2', 'reshape2', 'matmul_v2', 'elementwise_add', 'dropout',
            'elementwise_add', 'layer_norm', 'matmul_v2', 'elementwise_add',
            'gelu', 'matmul_v2', 'elementwise_add', 'dropout', 'elementwise_add'
        ]
        self.assertEqual(self._op_types(dist_main_prog), ref_ops)

        # Startup program: 16 initializer ops followed by 32 broadcasts.
        ref_ops = [
            'gaussian_random', 'gaussian_random', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'fill_constant',
            'fill_constant'
        ] + ['c_broadcast'] * 32
        self.assertEqual(self._op_types(dist_startup_prog), ref_ops)


# Run the test cases through unittest's command-line entry point when this
# file is executed directly.
if __name__ == "__main__":
    unittest.main()