# test_auto_parallel_completion.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
paddle.enable_static()

# Parallel strategy consulted by the model builders below when they annotate
# tensors with auto.shard_tensor: one of None, "dp", "mp", "dp_mp" or "pp".
# Each test assigns it before building its program.
_global_parallel_strategy = None
# Process mesh passed to auto.shard_tensor; assigned by each test.
_global_process_mesh = None
# Second mesh, only used by the "pp" strategy (for MLPLayer.linear1.weight).
_global_process_mesh2 = None
# 2 x 4 mesh over ranks 0-7; every per-test mesh is created with it as parent.
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])


class MLPLayer(nn.Layer):
    """Pre-norm feed-forward block used to exercise auto-parallel completion.

    ``forward`` applies LayerNorm -> Linear -> GELU -> Linear -> Dropout. It
    first annotates the two Linear weights with ``auto.shard_tensor``,
    depending on the module-level ``_global_parallel_strategy``, so that
    ``auto.complete_annotation`` has annotations to propagate.

    Args:
        hidden_size: input/output feature size (``d_model``).
        intermediate_size: inner feature size of the feed-forward layer.
        dropout_ratio: dropout probability applied to the block output.
        initializer_range: std of the Normal initializer for both weights.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # Both Linear layers are built from the same weight ParamAttr and use
        # the default bias.
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Annotate the weights for the active strategy, then run the block.

        ``dim_mapping[i]`` names the process-mesh axis that tensor dim ``i``
        is split over; ``-1`` leaves that dim unsplit. Under "dp" (or no
        strategy) the weights are left unannotated.
        """
        if _global_parallel_strategy == "mp":
            # linear0's second weight dim and linear1's first weight dim are
            # split over mesh axis 0.
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
        elif _global_parallel_strategy == "dp_mp":
            # Same split as "mp" but over mesh axis 1; mlp_pretrain_forward
            # uses axis 0 for the input's batch dim under this strategy.
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
        elif _global_parallel_strategy == "pp":
            # linear1's weight is placed on a second mesh, so the two layers
            # are annotated with different sets of ranks.
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh2,
                dim_mapping=[1, -1])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Append an MLPLayer forward pass to ``train_program``.

    The input is a float32 ``[4, 512, 1024]`` static placeholder named
    ``"input"``. Under the "dp" and "dp_mp" strategies its first (batch) dim
    is annotated as split over axis 0 of ``_global_process_mesh``.

    Args:
        train_program: main program the forward ops are recorded into.
        start_program: startup program paired with ``train_program``.

    Returns:
        The ``(train_program, start_program)`` pair that was passed in.
    """
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32')

        # "dp" and "dp_mp" annotate the input identically.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input, _global_process_mesh, dim_mapping=[0, -1, -1])

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Only the ops recorded into train_program matter; the output var
        # itself is not needed.
        mlp(input)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    """Completion tests for the MLP program under each parallel strategy.

    Each test sets the module-level strategy/mesh globals that ``MLPLayer``
    and ``mlp_pretrain_forward`` read. It then builds a fresh program and
    runs ``auto.complete_annotation`` on it.
    """

    def _complete_mlp(self):
        """Build a fresh MLP program and complete its annotations.

        Returns:
            ``(complete_train_program, dist_context)``.
        """
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(train_program,
                                                            start_program)
        complete_train_program = auto.complete_annotation(train_program,
                                                          dist_context)
        return complete_train_program, dist_context

    def _assert_completed(self, strategy, mesh):
        """Complete the MLP program under ``strategy`` on ``mesh``.

        Asserts that ``check_distributed_attr_for_program`` accepts the
        result.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh
        complete_train_program, dist_context = self._complete_mlp()
        self.assertTrue(
            check_distributed_attr_for_program(complete_train_program,
                                               dist_context))

    def test_mlp_dp(self):
        self._assert_completed(
            "dp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_mlp_mp(self):
        self._assert_completed(
            "mp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_mlp_dp_mp(self):
        self._assert_completed(
            "dp_mp", auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH))

    def test_mlp_misc(self):
        """Two-mesh ("pp") completion, attr finalization and printing."""
        global _global_parallel_strategy
        global _global_process_mesh
        global _global_process_mesh2
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1], [2, 3]], parent=ROOT_MESH)
        _global_process_mesh2 = auto.ProcessMesh(
            mesh=[[4, 5], [6, 7]], parent=ROOT_MESH)

        complete_train_program, dist_context = self._complete_mlp()

        # Presumably writes dist_context's attributes onto the var/op descs
        # that the loop below inspects.
        dist_context.finalize_distributed_attr_for_program(
            complete_train_program)
        mesh_id_attr = append_distributed_attr_suffix("mesh_id")
        dim_mapping_attr = append_distributed_attr_suffix("dim_mapping")
        # NOTE(review): desc.has_attr() presumably returns a bool, so
        # assertIsNotNone can never fail here. These checks only prove the
        # lookups run. Tighten to assertTrue once every var/op is known to
        # carry the attributes.
        for block in complete_train_program.blocks:
            for var in block.vars.values():
                self.assertIsNotNone(var.desc.has_attr(mesh_id_attr))
                self.assertIsNotNone(var.desc.has_attr(dim_mapping_attr))
            for op in block.ops:
                desc = op.desc
                self.assertIsNotNone(desc.has_attr(mesh_id_attr))
                for tensor_name in desc.input_arg_names():
                    attr_name = append_distributed_attr_suffix("IN_" +
                                                               tensor_name)
                    self.assertIsNotNone(desc.has_attr(attr_name))
                for tensor_name in desc.output_arg_names():
                    attr_name = append_distributed_attr_suffix("OUT_" +
                                                               tensor_name)
                    self.assertIsNotNone(desc.has_attr(attr_name))

        # Make dist_context the default one before rendering the program.
        set_default_distributed_context(dist_context)
        self.assertIn("dist_attr", str(complete_train_program))
        with unittest.mock.patch(
                "sys.stdout", new_callable=StringIO) as mock_stdout:
            print_program_with_distributed_attr(complete_train_program)
            self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used to exercise completion.

    ``forward`` projects the input to q/k/v, splits them into ``num_heads``
    heads, and applies scaled dot-product attention with optional mask and
    dropout. It then merges the heads and applies ``out_proj``. Along the
    way it annotates the input and the projection weights with
    ``auto.shard_tensor`` according to the module-level
    ``_global_parallel_strategy``.

    Args:
        hidden_size: model width; also used as the embed/key/value dim.
        sequence_len: stored on the layer; not used by ``forward``.
        intermediate_size: accepted for signature parity; unused here.
        num_heads: number of heads; must divide ``hidden_size``.
        dropout_ratio: dropout probability on the attention weights
            (skipped when falsy).
        initializer_range: std of the Normal initializer for all weights.
    """

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

    def forward(self, input):
        """Annotate for the active strategy and run self-attention.

        ``dim_mapping[i]`` names the process-mesh axis that tensor dim ``i``
        is split over; ``-1`` leaves that dim unsplit.
        """
        # "dp" and "dp_mp" both split the input's first (batch) dim over
        # mesh axis 0.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input, _global_process_mesh, dim_mapping=[0, -1, -1])

        q = self.q_proj(input)
        # Shape [0, 0, heads, head_dim]: a 0 keeps that input dim (Paddle
        # reshape convention), then heads are moved to axis 1.
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        # q/k/v weights: second dim split over mesh axis 0 ("mp") or 1
        # ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention: q @ k^T scaled by 1/sqrt(head_dim)
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads back into a single feature dim
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        # out_proj weight: first dim split over mesh axis 0 ("mp") or 1
        # ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh,
                dim_mapping=[0, -1])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh,
                dim_mapping=[1, -1])

        return out


def attn_pretrain_forward(train_program, start_program):
    """Append one AttentionLayer forward pass to ``train_program``.

    A float32 placeholder named ``"query"`` of shape ``[4, 512, 1024]`` is
    fed through a 16-head AttentionLayer built inside a fresh unique-name
    scope.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    batch, seq_len, hidden = 4, 512, 1024
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():
        query = static.data(
            name="query", shape=[batch, seq_len, hidden], dtype='float32')
        layer = AttentionLayer(
            hidden_size=hidden,
            sequence_len=seq_len,
            intermediate_size=4 * hidden,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Building the graph is the point; the output var is not used.
        layer(query)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    """Completion tests for the attention program under each strategy."""

    def _assert_completed(self, strategy, mesh):
        """Build the attention program under ``strategy`` on ``mesh``.

        Sets the module-level ``_global_parallel_strategy`` and
        ``_global_process_mesh`` read by ``AttentionLayer``. Runs
        ``auto.complete_annotation`` and asserts that
        ``check_distributed_attr_for_program`` accepts the result.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(train_program,
                                                             start_program)
        complete_train_program = auto.complete_annotation(train_program,
                                                          dist_context)
        self.assertTrue(
            check_distributed_attr_for_program(complete_train_program,
                                               dist_context))

    def test_attn_dp(self):
        self._assert_completed(
            "dp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_attn_mp(self):
        self._assert_completed(
            "mp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_attn_dp_mp(self):
        self._assert_completed(
            "dp_mp", auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH))


class DecoderLayer(nn.Layer):
    """Pre-norm transformer decoder layer (embeddings + attention + MLP).

    ``forward`` does the following:

    1. Sum word and position embeddings, then apply dropout.
    2. Run multi-head self-attention on ``norm1`` of the result and add it
       back as a residual.
    3. Run a GELU feed-forward block on ``norm2`` of that and add it back as
       a second residual.

    Inputs and weights are annotated with ``auto.shard_tensor`` according to
    the module-level ``_global_parallel_strategy``.

    Args:
        vocab_size: rows of the word embedding table.
        hidden_size: model width; also the embed/key/value dim.
        sequence_len: stored on the layer; not used by ``forward``.
        max_position_embeddings: rows of the position embedding table.
        intermediate_size: inner width of the feed-forward block.
        num_heads: attention heads; must divide ``hidden_size``.
        dropout_ratio: dropout probability used throughout.
        initializer_range: std of the Normal initializer for all weights.
    """

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range)))

        # Attention projections.
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)

        # Feed-forward block. Use the caller's intermediate_size instead of
        # recomputing it (the callers in this file pass 4 * hidden_size).
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Annotate for the active strategy and run the decoder layer.

        ``dim_mapping[i]`` names the process-mesh axis that tensor dim ``i``
        is split over; ``-1`` leaves that dim unsplit.
        """
        # "dp" and "dp_mp" both split the first (batch) dim of input_ids
        # over mesh axis 0.
        if _global_parallel_strategy in ("dp", "dp_mp"):
            auto.shard_tensor(
                input_ids, _global_process_mesh, dim_mapping=[0, -1])

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Word-embedding table: first (vocab) dim split over mesh axis 0
        # ("mp") or 1 ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                _global_process_mesh,
                dim_mapping=[0, -1])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.word_embeddings.weight,
                _global_process_mesh,
                dim_mapping=[1, -1])

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        # A 0 in the reshape target keeps that input dim (Paddle convention).
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        # q/k/v weights: second dim split over mesh axis 0 ("mp") or 1
        # ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention: q @ k^T scaled by 1/sqrt(head_dim)
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        # out_proj weight: first dim split over mesh axis 0 ("mp") or 1
        # ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh,
                dim_mapping=[0, -1])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.out_proj.weight, _global_process_mesh,
                dim_mapping=[1, -1])

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        # linear0's second weight dim / linear1's first weight dim split over
        # mesh axis 0 ("mp") or 1 ("dp_mp").
        if _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
        elif _global_parallel_strategy == "dp_mp":
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Append one DecoderLayer forward pass to ``train_program``.

    Two int64 ``[4, 512]`` placeholders, ``"input_ids"`` and
    ``"position_ids"``, feed a DecoderLayer with a 32768-entry vocabulary,
    hidden size 1024 and 16 heads, built inside a fresh unique-name scope.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    batch, seq_len, hidden = 4, 512, 1024
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():
        token_ids = static.data(
            name="input_ids", shape=[batch, seq_len], dtype='int64')
        pos_ids = static.data(
            name="position_ids", shape=[batch, seq_len], dtype='int64')
        layer = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden,
            sequence_len=seq_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02)
        # Building the graph is the point; the output var is not used.
        layer(token_ids, pos_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    """Completion tests for the decoder-layer program under each strategy."""

    def _assert_completed(self, strategy, mesh):
        """Build the decoder program under ``strategy`` on ``mesh``.

        Sets the module-level ``_global_parallel_strategy`` and
        ``_global_process_mesh`` read by ``DecoderLayer``. Runs
        ``auto.complete_annotation`` and asserts that
        ``check_distributed_attr_for_program`` accepts the result.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(train_program,
                                                                start_program)
        complete_train_program = auto.complete_annotation(train_program,
                                                          dist_context)
        self.assertTrue(
            check_distributed_attr_for_program(complete_train_program,
                                               dist_context))

    def test_decoder_dp(self):
        self._assert_completed(
            "dp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_decoder_mp(self):
        self._assert_completed(
            "mp", auto.ProcessMesh(
                mesh=[0, 1, 2, 3], parent=ROOT_MESH))

    def test_decoder_dp_mp(self):
        self._assert_completed(
            "dp_mp", auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH))


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")