test_auto_parallel_partitioner.py 47.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
25
from paddle.distributed.fleet import auto
26
from paddle.distributed.auto_parallel.completion import Completer
27
from paddle.distributed.auto_parallel.dist_context import DistributedContext
28 29 30
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
31
from paddle.distributed.auto_parallel.process_group import new_process_group
32 33

# Static-graph mode: the tests below build and inspect static.Program objects.
paddle.enable_static()

# Module-level configuration read by MLPLayer / AttentionLayer and the
# *_pretrain_forward builders; each test case assigns both before building.
_global_parallel_strategy = None
_global_process_mesh = None


def get_programs(annotated_func, rank_id=3):
    """Build a serial program with ``annotated_func`` and partition it for one rank.

    Args:
        annotated_func: callable ``(train_program, start_program) ->
            (train_program, start_program)`` that populates the two programs
            with an auto-parallel annotated network.
        rank_id: rank whose partitioned programs are produced. Defaults to 3,
            the rank the communication-group checks in this file assume.

    Returns:
        Tuple ``(complete_train_program, start_program, dist_main_prog,
        dist_startup_prog, dist_context)``.
    """
    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    # Reading the module-level mesh needs no `global` declaration.
    dist_context.process_mesh = _global_process_mesh
    train_program, start_program = annotated_func(train_program, start_program)

    # Run the Completer over the annotated forward program.
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    partitioner = Partitioner(dist_context, rank_id)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        complete_train_program, start_program, [])

    return (complete_train_program, start_program, dist_main_prog,
            dist_startup_prog, dist_context)


def is_all_parameters_shape_equal(prog1, prog2):
    """Return True when two programs hold parameters with matching shapes.

    Each program's parameter list is sorted by ``name`` and the ``shape``
    attributes are compared position by position. A differing number of
    parameters yields False.
    """
    first = prog1.all_parameters()
    second = prog2.all_parameters()
    first.sort(key=lambda param: param.name)
    second.sort(key=lambda param: param.name)

    first_shapes = [param.shape for param in first]
    second_shapes = [param.shape for param in second]
    if len(first_shapes) != len(second_shapes):
        return False
    return not any(lhs != rhs for lhs, rhs in zip(first_shapes, second_shapes))


def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
    """Return True if each var in ``prog1`` is a 1/``nsplit`` slice along ``axis``.

    ``varnames1[k]`` in ``prog1`` is paired with ``varnames2[k]`` in
    ``prog2``; every pair must satisfy
    ``shape1[axis] == shape2[axis] // nsplit``.
    """
    for idx, local_name in enumerate(varnames1):
        local_var = prog1.global_block().var(local_name)
        full_var = prog2.global_block().var(varnames2[idx])
        expected = full_var.shape[axis] // nsplit
        if local_var.shape[axis] != expected:
            return False
    return True


def initialization_check(mode,
                         dist_context,
                         dist_startup_prog,
                         serial_startup_prog,
                         var_need_broadcast,
                         process_mesh,
                         mp_parallel_axis,
                         dp_parallel_axis,
                         rank_id=3):
    """Check the ``c_broadcast`` ops present in a partitioned startup program.

    Args:
        mode: parallel strategy string. The substrings ``'mp'`` and ``'dp'``
            select which checks run (``"dp_mp"`` runs both).
        dist_context: not used by this check; kept for call-site compatibility.
        dist_startup_prog: partitioned startup program to inspect.
        serial_startup_prog: serial startup program. Its parameter count is
            the expected number of broadcasts in the dp group.
        var_need_broadcast: names of variables expected to be broadcast in the
            mp group (any order).
        process_mesh: mesh whose ``processes`` / ``topology`` define groups.
        mp_parallel_axis: mesh axis used to build the mp group.
        dp_parallel_axis: mesh axis used to build the dp group.
        rank_id: rank passed to ``_get_comm_group`` (default 3, matching the
            default rank used by ``get_programs``).

    Returns:
        True if every enabled check passes, otherwise False.
    """
    startup_ops = dist_startup_prog.global_block().ops

    def _broadcast_ops(ring_id=None):
        # c_broadcast ops of the startup program, optionally on one ring only.
        return [
            op for op in startup_ops if op.type == "c_broadcast" and
            (ring_id is None or op.desc.attr("ring_id") == ring_id)
        ]

    def _ring_id(axis):
        # Ring id of the comm group along `axis` for `rank_id`.
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, axis, rank_id)
        return new_process_group(group_ranks).id

    if 'mp' in mode:
        mp_ring_id = _ring_id(mp_parallel_axis)
        broadcast_varnames = sorted(op.desc.output_arg_names()[0]
                                    for op in _broadcast_ops(mp_ring_id))
        if broadcast_varnames != sorted(var_need_broadcast):
            return False

    if 'dp' in mode:
        dp_ring_id = _ring_id(dp_parallel_axis)
        nparam = len(serial_startup_prog.all_parameters())
        nbroadcast_dp = len(_broadcast_ops(dp_ring_id))
        if nparam != nbroadcast_dp:
            return False

    if 'dp' in mode and 'mp' in mode:
        # No c_broadcast ops other than the mp and dp ones may exist.
        if len(var_need_broadcast) + nbroadcast_dp != len(_broadcast_ops()):
            return False

    return True


def get_input_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of ``op``'s first input variable.

    The variable is looked up by name in ``main_program``'s global block, and
    its dist_attr is read from ``dist_context``.
    """
    first_input = op.desc.input_arg_names()[0]
    var = main_program.global_block().var(first_input)
    return dist_context.get_tensor_dist_attr_for_program(var)


def get_output_var_dist_attr(op, main_program, dist_context):
    """Return the tensor dist_attr of ``op``'s first output variable.

    The variable is looked up by name in ``main_program``'s global block, and
    its dist_attr is read from ``dist_context``.
    """
    first_output = op.desc.output_arg_names()[0]
    var = main_program.global_block().var(first_output)
    return dist_context.get_tensor_dist_attr_for_program(var)


def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
    """Return True if two tensor dist_attrs agree on mesh and dims_mapping."""
    mesh_differs = serial_dist_attr.process_mesh != dist_attr.process_mesh
    mapping_differs = serial_dist_attr.dims_mapping != dist_attr.dims_mapping
    return not (mesh_differs or mapping_differs)


def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
                             dist_op_idx):
    """Check the op dist_attrs of the dist ops derived from one serial op.

    For each index ``i`` in ``dist_op_idx``, ``dist_ops[i]`` must satisfy all
    of the following:

    * For every input and output variable, the dims_mapping its op dist_attr
      records equals that variable's own tensor dist_attr dims_mapping in
      ``dist_main_prog``.
    * Its desc id differs from ``serial_op``'s.
    * It has the same process_mesh and impl_idx as ``serial_op``.

    Returns:
        True if every listed op passes, otherwise False (stops at the first
        mismatch).
    """
    serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op)
    serial_process_mesh = serial_op_dist_attr.process_mesh
    serial_impl_idx = serial_op_dist_attr.impl_idx

    def _tensor_dims_mapping(varname):
        # dims_mapping stored on the variable itself (tensor dist_attr).
        var = dist_main_prog.global_block().var(varname)
        return dist_context.get_tensor_dist_attr_for_program(var).dims_mapping

    for i in dist_op_idx:
        dist_op = dist_ops[i]
        op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)

        for in_varname in dist_op.desc.input_arg_names():
            if (_tensor_dims_mapping(in_varname) !=
                    op_dist_attr.get_input_dims_mapping(in_varname)):
                return False
        for out_varname in dist_op.desc.output_arg_names():
            if (_tensor_dims_mapping(out_varname) !=
                    op_dist_attr.get_output_dims_mapping(out_varname)):
                return False

        # The dist op must be a distinct op that inherits the serial op's
        # process mesh and implementation index.
        if serial_op.desc.id() == dist_op.desc.id() or \
                serial_process_mesh != op_dist_attr.process_mesh or \
                serial_impl_idx != op_dist_attr.impl_idx:
            return False

    return True


def distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                       dist_context, serial_op_idx,
                                       dist_op_idx):
    """Check dist_attrs of the dist ops generated for selected serial ops.

    Args:
        serial_main_prog: serial (completed) main program.
        dist_main_prog: partitioned main program.
        dist_context: context holding the tensor/op dist_attrs.
        serial_op_idx: indices into ``serial_main_prog``'s global-block ops.
        dist_op_idx: for each entry of ``serial_op_idx``, the list of indices
            of the dist ops derived from that serial op. The first one is used
            for the variable dist_attr check.

    Returns:
        True only if every serial op passes both the variable dist_attr check
        and the op dist_attr check; False as soon as any check fails.
    """
    serial_ops = serial_main_prog.global_block().ops
    dist_ops = dist_main_prog.global_block().ops

    for i, serial_i in enumerate(serial_op_idx):
        serial_op = serial_ops[serial_i]
        dist_group = dist_op_idx[i]
        dist_op_0 = dist_ops[dist_group[0]]

        if dist_op_0.type == "c_identity":
            # Serial op input's dist_attr vs. the c_identity output's (new
            # var) dist_attr.
            serial_attr = get_input_var_dist_attr(serial_op, serial_main_prog,
                                                  dist_context)
        else:
            # Serial op output's dist_attr vs. the dist op output's (new var)
            # dist_attr.
            serial_attr = get_output_var_dist_attr(serial_op, serial_main_prog,
                                                   dist_context)
        dist_attr = get_output_var_dist_attr(dist_op_0, dist_main_prog,
                                             dist_context)

        # Fail fast: every pair must pass BOTH checks. Overwriting a single
        # flag would let a later success hide an earlier failure.
        if not check_equal_var_dist_attr(serial_attr, dist_attr):
            return False
        if not check_equal_dist_op_attr(dist_context, dist_main_prog,
                                        serial_op, dist_ops, dist_group):
            return False

    return True


def distributed_attr_check_for_program(dist_main_prog, dist_context):
    """Return True if every var and op of every block has a dist_attr.

    A var or op counts as missing its dist_attr when ``dist_context`` returns
    None for it. Returns False at the first missing one.
    """
    for block in dist_main_prog.blocks:
        for var in block.vars.values():
            if dist_context.get_tensor_dist_attr_for_program(var) is None:
                return False
        for op in block.ops:
            if dist_context.get_op_dist_attr_for_program(op) is None:
                return False
    return True


class MLPLayer(nn.Layer):
    """Feed-forward test network: LayerNorm -> Linear -> GELU -> Linear -> Dropout.

    ``forward`` annotates the two Linear weights according to the
    module-level ``_global_parallel_strategy`` / ``_global_process_mesh``.
    Under "mp" / "dp_mp", ``linear0.weight`` is sharded as ``[None, "mp"]``
    and ``linear1.weight`` as ``["mp", None]``. Otherwise both are left
    unsharded.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        # Shard along the "mp" mesh dim only for model-parallel strategies;
        # linear0 is split on its output dim, linear1 on its input dim.
        use_mp = _global_parallel_strategy in ["mp", "dp_mp"]
        mp_dim = "mp" if use_mp else None
        auto.shard_tensor(self.linear0.weight,
                          process_mesh=_global_process_mesh,
                          shard_spec=[None, mp_dim])
        auto.shard_tensor(self.linear1.weight,
                          process_mesh=_global_process_mesh,
                          shard_spec=[mp_dim, None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build the MLP test network into ``train_program`` / ``start_program``.

    The network input is a float32 placeholder named ``"input"`` with shape
    ``[4, 512, 1024]``. Under "dp" / "dp_mp" its first axis is sharded on the
    ``"dp"`` mesh dim.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_data = static.data(name="input",
                                 shape=[batch_size, sequence_len, hidden_size],
                                 dtype='float32')

        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input_data,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       dropout_ratio=0.1,
                       initializer_range=0.02)
        # Called for its side effect of appending ops to train_program.
        mlp(input_data)
    return train_program, start_program


class TestMLPAutoPartitioner(unittest.TestCase):
    """Partitioner checks for ``MLPLayer`` under dp, mp and dp_mp strategies.

    Each test sets the module-level strategy/mesh globals, partitions the
    serial program with ``get_programs`` and inspects the resulting programs.
    """

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            mlp_pretrain_forward)

        # Parameters must not be partitioned.
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # The main program keeps the serial op sequence unchanged.
        self.assertEqual(
            [op.type for op in dist_main_prog.global_block().ops],
            [op.type for op in serial_main_prog.global_block().ops])

        # Parameter initialization: no mp broadcasts are expected.
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog, [],
                                 _global_process_mesh,
                                 mp_parallel_axis=None,
                                 dp_parallel_axis=0))

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        self._check_mp_partition(mp_parallel_axis=0, dp_parallel_axis=None)

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        self._check_mp_partition(mp_parallel_axis=1, dp_parallel_axis=0)

    def _check_mp_partition(self, mp_parallel_axis, dp_parallel_axis):
        """Checks shared by every strategy that includes model parallelism."""
        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            mlp_pretrain_forward)

        # Parameters are split across the 4 mp ranks:
        # (name, split axis, number of pieces).
        nrank = 4
        expected_splits = [
            ('linear_0.w_0', 1, nrank),  # column parallel
            ('linear_0.b_0', 0, nrank),
            ('linear_1.w_0', 0, nrank),  # row parallel
            ('linear_1.b_0', 0, 1),  # row-parallel bias stays whole
        ]
        for name, axis, nsplit in expected_splits:
            self.assertTrue(
                check_tensor_split(dist_main_prog, [name], serial_main_prog,
                                   [name], axis, nsplit), name)

        # Column/row parallel communication ops.
        ref_ops = [
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout'
        ]
        self.assertEqual([op.type for op in dist_main_prog.global_block().ops],
                         ref_ops)

        # Parameter initialization.
        var_need_broadcast = sorted(
            ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=mp_parallel_axis,
                                 dp_parallel_axis=dp_parallel_axis))

        # Every var and op in the dist main program has a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # Distributed attrs of the dist ops: per ref_ops, dist ops [1, 2] are
        # (c_identity, matmul_v2) and [5, 6] are (matmul_v2, c_allreduce_sum).
        serial_op_idx = [1, 4]
        dist_op_idx = [[1, 2], [5, 6]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))
class AttentionLayer(nn.Layer):
    """Multi-head self-attention test network.

    Annotations in ``forward`` follow the module-level
    ``_global_parallel_strategy`` / ``_global_process_mesh``:

    * Under "dp" / "dp_mp", the input's first axis is sharded on ``"dp"``.
    * Under "mp" / "dp_mp", the q/k/v projection weights are sharded as
      ``[None, "mp"]`` and the output projection weight as ``["mp", None]``.

    ``sequence_len`` is stored but not used below; ``intermediate_size`` is
    accepted but ignored.
    """

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

    def forward(self, input):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        # Split heads. Assumes input is [batch, seq, embed_dim] (TODO confirm);
        # reshape keeps dims marked 0, then transpose swaps axes 1 and 2.
        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            for proj in (self.q_proj, self.k_proj, self.v_proj):
                auto.shard_tensor(proj.weight,
                                  process_mesh=_global_process_mesh,
                                  shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # Scaled dot-product attention: q @ k^T * head_dim ** -0.5.
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # Combine heads back into a single embedding axis.
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # Project to output.
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build the attention test network into ``train_program`` / ``start_program``.

    The network input is a float32 placeholder named ``"query"`` with shape
    ``[4, 512, 1024]``. Sharding annotations are applied inside
    ``AttentionLayer.forward``.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        query = static.data(name="query",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')
        attn = AttentionLayer(hidden_size=hidden_size,
                              sequence_len=sequence_len,
                              intermediate_size=4 * hidden_size,
                              num_heads=16,
                              dropout_ratio=0.1,
                              initializer_range=0.02)
        # Called for its side effect of appending ops to train_program.
        attn(query)

    return train_program, start_program


class TestAttentionAutoPartitioner(unittest.TestCase):
    """Partitioner checks for ``AttentionLayer`` under dp, mp and dp_mp.

    Each test sets the module-level strategy/mesh globals, partitions the
    serial program with ``get_programs`` and inspects the resulting programs.
    """

    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])

        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            attn_pretrain_forward)

        # Parameters must not be partitioned.
        self.assertTrue(
            is_all_parameters_shape_equal(serial_main_prog, dist_main_prog))
        self.assertTrue(
            is_all_parameters_shape_equal(serial_startup_prog,
                                          dist_startup_prog))

        # The main program keeps the serial op sequence unchanged.
        self.assertEqual(
            [op.type for op in dist_main_prog.global_block().ops],
            [op.type for op in serial_main_prog.global_block().ops])

        # Parameter initialization: no mp broadcasts are expected.
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog, [],
                                 _global_process_mesh,
                                 mp_parallel_axis=None,
                                 dp_parallel_axis=0))

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])
        self._check_mp_partition(mp_parallel_axis=0, dp_parallel_axis=None)

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        self._check_mp_partition(mp_parallel_axis=1, dp_parallel_axis=0)

    def _check_mp_partition(self, mp_parallel_axis, dp_parallel_axis):
        """Checks shared by every strategy that includes model parallelism."""
        serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
            attn_pretrain_forward)

        # Parameters are split across the 4 mp ranks. linear_0..2 are column
        # parallel; linear_3 is row parallel and its bias stays whole.
        nrank = 4
        col_weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0']
        col_biases = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0']
        expected_splits = [
            (col_weights, 1, nrank),
            (col_biases, 0, nrank),
            (['linear_3.w_0'], 0, nrank),
            (['linear_3.b_0'], 0, 1),
        ]
        for names, axis, nsplit in expected_splits:
            self.assertTrue(
                check_tensor_split(dist_main_prog, names, serial_main_prog,
                                   names, axis, nsplit), names)

        # Column/row parallel communication ops.
        ref_ops = [
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'c_identity', 'matmul_v2', 'elementwise_add',
            'c_identity', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'reshape2', 'transpose2', 'matmul', 'softmax',
            'dropout', 'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add'
        ]
        self.assertEqual([op.type for op in dist_main_prog.global_block().ops],
                         ref_ops)

        # Parameter initialization.
        var_need_broadcast = ['linear_3.b_0']
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=mp_parallel_axis,
                                 dp_parallel_axis=dp_parallel_axis))

        # Every var and op in the dist main program has a dist_attr.
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # Distributed attrs of the dist ops: per ref_ops, each listed group
        # starts with c_identity (q/k/v) or ends with c_allreduce_sum (out).
        serial_op_idx = [0, 4, 6, 18]
        dist_op_idx = [[0, 1], [5, 6], [8, 9], [21, 22]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

782 783

class DecoderLayer(nn.Layer):
    """A single pre-norm transformer decoder block used as a partitioner fixture.

    Structure: word + position embeddings -> dropout -> LayerNorm ->
    multi-head self-attention -> residual add -> LayerNorm -> GELU MLP ->
    residual add.

    ``forward`` annotates the input ids and several parameters with
    ``auto.shard_tensor`` depending on the module-level
    ``_global_parallel_strategy`` ("dp", "mp", "dp_mp" or anything else for
    no sharding) and ``_global_process_mesh``. The same layer can therefore
    be built serially or in data/model-parallel configurations.
    """

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=None,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        """Create the embedding, attention, MLP and normalization sublayers.

        Args:
            vocab_size: number of rows in the word embedding table.
            hidden_size: model width; also the attention embed dim.
            sequence_len: stored on the layer; not used by ``forward``.
            max_position_embeddings: number of rows in the position table.
            intermediate_size: hidden width of the MLP. Defaults to
                ``4 * hidden_size`` when ``None``.
            num_heads: number of attention heads; must divide ``hidden_size``.
            dropout_ratio: dropout probability used throughout the block.
            initializer_range: std of the Normal weight initializer.
        """
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="word_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="pos_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))

        # Attention projections (created in the order q, k, v, out).
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

        # MLP: d_model -> dim_feedforward -> d_model.
        d_model = self.hidden_size
        if intermediate_size is None:
            dim_feedforward = 4 * self.hidden_size
        else:
            dim_feedforward = intermediate_size
        self.intermediate_size = dim_feedforward
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)

        # A single LayerNorm instance is reused at both pre-norm sites in
        # forward(), so the block owns only one set of norm parameters.
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Run the decoder block and attach sharding annotations.

        The order of the layer calls below fixes the op order of the built
        program, which the partitioner tests compare op-by-op.

        Args:
            input_ids: token id variable (built as int64
                ``[batch_size, sequence_len]`` by ``decoder_pretrain_forward``).
            position_ids: position id variable with the same layout.

        Returns:
            The block output: ``residual + dropout3(mlp(norm(residual)))``.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            # Split the batch dimension across the "dp" mesh axis.
            auto.shard_tensor(input_ids,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None])

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Annotations below are placed after the corresponding layer calls.
        # get_programs() only runs the Completer once the whole program has
        # been built, so the annotations are presumably still picked up.
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.word_embeddings.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # q/k/v weights are split on their output (second) dimension.
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # The output projection is split on its input (first) dimension.
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])
        else:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, None])

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm (same LayerNorm instance as above)
        out0 = self.norm(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build a single DecoderLayer forward pass into the given programs.

    Args:
        train_program: static.Program that receives the forward ops.
        start_program: static.Program used as the startup program (it
            receives the parameter initialization ops).

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    # unique_name.guard() starts a fresh name scope, so generated parameter
    # names (linear_0.w_0, layer_norm_0.b_0, ...) do not depend on graphs
    # built earlier; the tests look parameters up by these names.
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(name="input_ids",
                                shape=[batch_size, sequence_len],
                                dtype='int64')
        position_ids = static.data(name="position_ids",
                                   shape=[batch_size, sequence_len],
                                   dtype='int64')
        decoder = DecoderLayer(vocab_size=32768,
                               hidden_size=hidden_size,
                               sequence_len=sequence_len,
                               max_position_embeddings=512,
                               intermediate_size=4 * hidden_size,
                               num_heads=16,
                               dropout_ratio=0.1,
                               initializer_range=0.02)
        # Only the ops recorded into train_program matter; the returned
        # output variable is not needed.
        decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerPartitioner(unittest.TestCase):
    """Partitioner checks for DecoderLayer with dp+mp and with no parallelism."""

    def _assert_params_partitioned(self, serial_main_prog, dist_main_prog,
                                   nrank):
        """Assert how each DecoderLayer parameter was split by the partitioner.

        Args:
            serial_main_prog: the serial (unpartitioned) main program.
            dist_main_prog: the partitioned main program.
            nrank: number of shards expected for model-parallel parameters
                (1 when nothing is model-parallel).
        """
        # col parallel: q/k/v projections (linear_0..2) and the first MLP
        # linear (linear_4). Weights are checked on axis 1, biases on axis 0.
        weights = [
            'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 1, nrank))
        weights = [
            'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        # row parallel: word embeddings, attention output projection
        # (linear_3) and the second MLP linear (linear_5), checked on axis 0.
        weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0']
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, nrank))
        # Everything else is expected to be unsplit (a single shard).
        weights = [
            'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
            'layer_norm_0.w_0', 'linear_5.b_0'
        ]
        self.assertTrue(
            check_tensor_split(dist_main_prog, weights, serial_main_prog,
                               weights, 0, 1))

    def test_decoder_dp_mp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp_mp"
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog,
         dist_context) = get_programs(decoder_pretrain_forward)

        # The "mp" axis of the 2 x 4 mesh has 4 ranks.
        nrank = 4
        self._assert_params_partitioned(serial_main_prog, dist_main_prog,
                                        nrank)

        # row and col allreduce
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'c_embedding', 'c_allreduce_sum', 'lookup_table_v2',
            'elementwise_add', 'dropout', 'layer_norm', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'c_identity', 'matmul_v2', 'elementwise_add', 'c_identity',
            'matmul_v2', 'elementwise_add', 'reshape2', 'transpose2',
            'reshape2', 'transpose2', 'matmul', 'softmax', 'dropout',
            'matmul_v2', 'transpose2', 'reshape2', 'matmul_v2',
            'c_allreduce_sum', 'elementwise_add', 'dropout', 'elementwise_add',
            'layer_norm', 'c_identity', 'matmul_v2', 'elementwise_add', 'gelu',
            'matmul_v2', 'c_allreduce_sum', 'elementwise_add', 'dropout',
            'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # parameter initialization
        var_need_broadcast = sorted([
            'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0',
            'layer_norm_0.w_0', 'linear_5.b_0'
        ])
        self.assertTrue(
            initialization_check(_global_parallel_strategy,
                                 dist_context,
                                 dist_startup_prog,
                                 serial_startup_prog,
                                 var_need_broadcast,
                                 _global_process_mesh,
                                 mp_parallel_axis=1,
                                 dp_parallel_axis=0))

        # check var and op all have dist_attr in dist_main_program
        self.assertTrue(
            distributed_attr_check_for_program(dist_main_prog, dist_context))
        # check distributed attr
        serial_op_idx = [0, 5, 9, 11, 23, 28, 31]
        dist_op_idx = [[0, 1], [6, 7], [11, 12], [14, 15], [27, 28], [33, 34],
                       [37, 38]]
        self.assertTrue(
            distributed_attr_check_for_dist_op(serial_main_prog, dist_main_prog,
                                               dist_context, serial_op_idx,
                                               dist_op_idx))

    def test_decoder_noparallel(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "None"
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["x", "y"])
        (serial_main_prog, serial_startup_prog, dist_main_prog,
         dist_startup_prog,
         dist_context) = get_programs(decoder_pretrain_forward)

        # No model parallelism: every parameter stays in one shard.
        nrank = 1
        self._assert_params_partitioned(serial_main_prog, dist_main_prog,
                                        nrank)

        # No communication ops are expected in the main program.
        dist_ops = [op.type for op in dist_main_prog.global_block().ops]
        ref_ops = [
            'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout',
            'layer_norm', 'matmul_v2', 'elementwise_add', 'reshape2',
            'transpose2', 'matmul_v2', 'elementwise_add', 'matmul_v2',
            'elementwise_add', 'reshape2', 'transpose2', 'reshape2',
            'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2',
            'transpose2', 'reshape2', 'matmul_v2', 'elementwise_add', 'dropout',
            'elementwise_add', 'layer_norm', 'matmul_v2', 'elementwise_add',
            'gelu', 'matmul_v2', 'elementwise_add', 'dropout', 'elementwise_add'
        ]
        self.assertEqual(dist_ops, ref_ops)

        # Startup program: 16 initializer ops followed by 32 broadcasts.
        dist_ops = [op.type for op in dist_startup_prog.global_block().ops]
        ref_ops = [
            'gaussian_random', 'gaussian_random', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'gaussian_random',
            'fill_constant', 'gaussian_random', 'fill_constant',
            'gaussian_random', 'fill_constant', 'fill_constant',
            'fill_constant'
        ] + ['c_broadcast'] * 32
        self.assertEqual(dist_ops, ref_ops)


if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main(module='__main__')