# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
# All programs in this file are built as static graphs.
paddle.enable_static()

# Module-level configuration read by the layers and *_pretrain_forward
# builders below while a program is being built. Tests assign the strategy
# and mesh before building. Recognised strategy values in this file are
# "dp", "mp", "dp_mp" and "pp"; any other value means no sharding annotation.
_global_parallel_strategy = None
# Process mesh used by every `auto.shard_tensor` annotation.
_global_process_mesh = None
# Second process mesh; only read by the "pp" branch of MLPLayer.forward
# (for linear1.weight).
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear -> Dropout.

    Depending on the module-level `_global_parallel_strategy`, `forward`
    annotates the two Linear weights with `auto.shard_tensor` so that
    `auto.complete_annotation` has seed annotations to propagate from.
    Under "dp" no weight is annotated here; the input is annotated by the
    program builder instead.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        # linear0.weight is split on its second dim and linear1.weight on its
        # first dim, along mesh axis 0 for "mp" and mesh axis 1 for "dp_mp".
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })
        elif _global_parallel_strategy == "pp":
            # The two weights are placed on different meshes, so the two
            # Linear layers end up on different process groups.
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh2,
                    "dims_mapping": [1, -1]
                })

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build a single MLPLayer forward pass into `train_program`.

    For the data-parallel strategies ("dp" and "dp_mp") the input is
    annotated as split on its first (batch) dim along mesh axis 0.

    Args:
        train_program: main static.Program that receives the forward ops.
        start_program: startup static.Program passed alongside it to
            `static.program_guard`.

    Returns:
        The same `(train_program, start_program)` pair that was passed in.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_tensor = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')

        # "dp" and "dp_mp" annotate the input identically.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input_tensor,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1, -1]
                })

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Calling the layer appends its ops to train_program; the output
        # variable itself is not needed.
        mlp(input_tensor)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    """Check `auto.complete_annotation` on the MLP program per strategy."""

    def _complete_and_validate(self, strategy, mesh):
        """Build the MLP program, complete its annotation and validate it.

        Args:
            strategy: value stored in the module-level
                `_global_parallel_strategy` (e.g. "dp", "mp", "dp_mp"),
                which the builders read while constructing the program.
            mesh: process-id list (1-D or 2-D) passed to `auto.ProcessMesh`
                and stored in `_global_process_mesh`.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(train_program,
                                                            start_program)
        # The returned program is not inspected here; when debugging, pass it
        # to print_program_with_dist_attr(program, dist_context).
        auto.complete_annotation(train_program, dist_context)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3])

    def test_mlp_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3])

    def test_mlp_dp_mp(self):
        self._complete_and_validate("dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]])

    # NOTE(review): disabled pipeline-parallel ("pp") test, kept for
    # reference. It is the only test that sets `_global_process_mesh2`.
    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     complete_train_program = auto.complete_annotation(train_program,
    #                                                       dist_context)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used to test completion of attention.

    `forward` projects its input to Q/K/V, splits heads, applies scaled
    dot-product attention (plus the optional `attn_mask` and dropout), merges
    heads and applies `out_proj`. Depending on the module-level
    `_global_parallel_strategy`, it annotates the input and the projection
    weights with `auto.shard_tensor`.

    `intermediate_size` is accepted but not used by this layer.
    """

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

    def forward(self, input):
        # "dp" and "dp_mp" annotate the input identically: first dim split
        # along mesh axis 0.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1, -1]
                })

        # A 0 in the reshape target keeps the corresponding input dim.
        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        # Q/K/V weights are split on their second dim: mesh axis 0 for "mp",
        # mesh axis 1 for "dp_mp".
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        # out_proj.weight is split on its first dim, mirroring Q/K/V.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build one AttentionLayer forward pass into `train_program`.

    Args:
        train_program: main static.Program that receives the forward ops.
        start_program: startup static.Program passed alongside it to
            `static.program_guard`.

    Returns:
        The same `(train_program, start_program)` pair that was passed in.
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')
        attention = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Invoking the layer records its ops in train_program.
        attention(query)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    """Check `auto.complete_annotation` on the attention program per strategy."""

    def _complete_and_validate(self, strategy, mesh):
        """Build the attention program, complete its annotation and validate.

        Args:
            strategy: value stored in the module-level
                `_global_parallel_strategy` (e.g. "dp", "mp", "dp_mp"),
                which the builders read while constructing the program.
            mesh: process-id list (1-D or 2-D) passed to `auto.ProcessMesh`
                and stored in `_global_process_mesh`.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(train_program,
                                                             start_program)
        # The returned program is not inspected here; when debugging, pass it
        # to print_program_with_dist_attr(program, dist_context).
        auto.complete_annotation(train_program, dist_context)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3])

    def test_attn_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3])

    def test_attn_dp_mp(self):
        self._complete_and_validate("dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]])


class DecoderLayer(nn.Layer):
    """Embedding + pre-norm attention + pre-norm MLP decoder block.

    `forward(input_ids, position_ids)` sums word and position embeddings,
    applies dropout, runs a self-attention sub-block and an MLP sub-block,
    each with a residual connection. Depending on the module-level
    `_global_parallel_strategy`, it annotates `input_ids` and the layer
    weights with `auto.shard_tensor`.
    """

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))

        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

        # Honor the `intermediate_size` argument (previously it was silently
        # overwritten with 4 * hidden_size), matching MLPLayer.
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)

        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        # "dp" and "dp_mp" annotate input_ids identically: first dim split
        # along mesh axis 0.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input_ids,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # The word-embedding table is split on its first dim: mesh axis 0 for
        # "mp", mesh axis 1 for "dp_mp".
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        # Q/K/V weights are split on their second dim.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.k_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.v_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        # out_proj.weight is split on its first dim.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        # linear0.weight is split on its second dim, linear1.weight on its first.
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [1, -1]
                })

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build one DecoderLayer forward pass into `train_program`.

    Args:
        train_program: main static.Program that receives the forward ops.
        start_program: startup static.Program passed alongside it to
            `static.program_guard`.

    Returns:
        The same `(train_program, start_program)` pair that was passed in.
    """
    batch_size, sequence_len, hidden_size = 4, 512, 1024
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():
        token_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
        positions = static.data(
            name="position_ids",
            shape=[batch_size, sequence_len],
            dtype='int64')
        block = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Invoking the layer records its ops in train_program.
        block(token_ids, positions)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    """Check `auto.complete_annotation` on the decoder program per strategy."""

    def _complete_and_validate(self, strategy, mesh):
        """Build the decoder program, complete its annotation and validate.

        Args:
            strategy: value stored in the module-level
                `_global_parallel_strategy` (e.g. "dp", "mp", "dp_mp"),
                which the builders read while constructing the program.
            mesh: process-id list (1-D or 2-D) passed to `auto.ProcessMesh`
                and stored in `_global_process_mesh`.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(train_program,
                                                                start_program)
        # The returned program is not inspected here; when debugging, pass it
        # to print_program_with_dist_attr(program, dist_context).
        auto.complete_annotation(train_program, dist_context)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3])

    def test_decoder_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3])

    def test_decoder_dp_mp(self):
        self._complete_and_validate("dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]])


if __name__ == "__main__":
    # Discover and run every unittest.TestCase defined in this module
    # (module="__main__" is unittest.main's default, spelled out).
    unittest.main(module="__main__")