# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context

paddle.enable_static()
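# Globals read by the layers and tests below: the strategy string selects
# which auto.shard_tensor annotations are applied, and the process meshes
# describe the device grids those annotations refer to.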
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
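    """Two-layer feed-forward block (LayerNorm, Linear, GELU, Linear,
    Dropout) whose Linear weights carry auto.shard_tensor annotations
    selected by the global parallel strategy."""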

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })
        elif _global_parallel_strategy == "pp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh2,
                                  "dims_mapping": [1, -1]
                              })

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
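    """Build a static program running a single MLPLayer forward pass; the
    input is sharded along the data-parallel mesh dimension for the dp and
    dp_mp strategies."""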
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(name="input",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')

        if _global_parallel_strategy == "dp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       dropout_ratio=0.1,
                       initializer_range=0.02)
        out = mlp(input)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
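    """Each test sets the global strategy and process mesh, builds the MLP
    forward program, runs the Completer to fill in the remaining distributed
    attributes, and validates the completed program."""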

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = completer.complete_forward_annotation(
    #         train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
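    """Multi-head self-attention block (q/k/v/out projections); the
    projection weights carry auto.shard_tensor annotations selected by the
    global parallel strategy."""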

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

    def forward(self, input):
        if _global_parallel_strategy == "dp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(input,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1, -1]
                              })

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        return out


def attn_pretrain_forward(train_program, start_program):
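    """Build a static program running a single AttentionLayer forward pass
    on a float32 query tensor."""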
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(name="query",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')
        attn = AttentionLayer(hidden_size=hidden_size,
                              sequence_len=sequence_len,
                              intermediate_size=4 * hidden_size,
                              num_heads=16,
                              dropout_ratio=0.1,
                              initializer_range=0.02)
        out = attn(input)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
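    """Completion checks for AttentionLayer under the dp, mp and dp_mp
    strategies."""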

    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        # print_program_with_dist_attr(complete_train_program,
        #                                     dist_context)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())


class DecoderLayer(nn.Layer):
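    """Transformer decoder block: word/position embeddings, pre-norm
    self-attention and a pre-norm MLP, each followed by a residual
    connection; weights carry auto.shard_tensor annotations selected by
    the global parallel strategy."""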

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="word_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="pos_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))

        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        if _global_parallel_strategy == "dp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(input_ids,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.word_embeddings.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.q_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.k_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.v_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.out_proj.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 0]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [0, -1]
                              })
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(self.linear0.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [-1, 1]
                              })
            auto.shard_tensor(self.linear1.weight,
                              dist_attr={
                                  "process_mesh": _global_process_mesh,
                                  "dims_mapping": [1, -1]
                              })

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
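    """Build a static program running a single DecoderLayer forward pass
    on int64 input_ids and position_ids."""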
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(name="input_ids",
                                shape=[batch_size, sequence_len],
                                dtype='int64')
        position_ids = static.data(name="position_ids",
                                   shape=[batch_size, sequence_len],
                                   dtype='int64')
        decoder = DecoderLayer(vocab_size=32768,
                               hidden_size=hidden_size,
                               sequence_len=sequence_len,
                               max_position_embeddings=512,
                               intermediate_size=4 * hidden_size,
                               num_heads=16,
                               dropout_ratio=0.1,
                               initializer_range=0.02)
        out = decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
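    """Completion checks for DecoderLayer under the dp, mp and dp_mp
    strategies."""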

    def test_decoder_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())


if __name__ == "__main__":
    unittest.main()