# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

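# Unit tests for the auto-parallel completion pass: each model below annotates
# only a few tensors with auto.shard_tensor, and the Completer is expected to
# infer consistent distributed attributes for every remaining tensor and op.
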
from __future__ import print_function

import unittest
import unittest.mock
from io import StringIO

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context

paddle.enable_static()
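# Each test case sets the strategy and process mesh below; the layers read them
# to decide which parameters and inputs to annotate with auto.shard_tensor.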
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh2 = None


class MLPLayer(nn.Layer):

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
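        # Under "mp"/"dp_mp", linear0.weight is split along its output (column)
        # dimension and linear1.weight along its input (row) dimension of the
        # "mp" mesh axis; only these annotations are provided by hand.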
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(name="input",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')

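        # Data parallelism shards the batch dimension of the input tensor.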
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       dropout_ratio=0.1,
                       initializer_range=0.02)
        out = mlp(input)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
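    # Each test builds a serial program, runs the Completer over it, and checks
    # that the completed distributed attributes validate for the whole program.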

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = completer.complete_forward_annotation(
    #         train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


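# Multi-head self-attention block annotated for Megatron-style tensor
# parallelism: column-parallel q/k/v projections and a row-parallel out_proj.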
class AttentionLayer(nn.Layer):

    def __init__(self,
                 hidden_size=1024,
                 sequence_len=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(AttentionLayer, self).__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

    def forward(self, input):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None, None])

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

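        # q/k/v weights are split along their output (column) dimension so each
        # "mp" rank holds a subset of the attention heads.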
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        return out


def attn_pretrain_forward(train_program, start_program):
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(name="query",
                            shape=[batch_size, sequence_len, hidden_size],
                            dtype='float32')
        attn = AttentionLayer(hidden_size=hidden_size,
                              sequence_len=sequence_len,
                              intermediate_size=4 * hidden_size,
                              num_heads=16,
                              dropout_ratio=0.1,
                              initializer_range=0.02)
        out = attn(input)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):

    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())


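# A single pre-norm transformer decoder block (embeddings, self-attention and
# an MLP sublayer), each part annotated with the sharding used by the tests.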
class DecoderLayer(nn.Layer):

    def __init__(self,
                 vocab_size=32768,
                 hidden_size=1024,
                 sequence_len=512,
                 max_position_embeddings=512,
                 intermediate_size=4 * 1024,
                 num_heads=16,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super(DecoderLayer, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="word_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(name="pos_embeddings",
                                         initializer=nn.initializer.Normal(
                                             mean=0.0,
                                             std=self.initializer_range)))

        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.q_proj = nn.Linear(self.embed_dim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.k_proj = nn.Linear(self.kdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.v_proj = nn.Linear(self.vdim,
                                self.embed_dim,
                                weight_attr,
                                bias_attr=bias_attr)
        self.out_proj = nn.Linear(self.embed_dim,
                                  self.embed_dim,
                                  weight_attr,
                                  bias_attr=bias_attr)

        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=self.initializer_range))
        bias_attr = None
        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(input_ids,
                              process_mesh=_global_process_mesh,
                              shard_spec=["dp", None])

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

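        # The word embedding table is sharded along its vocabulary dimension.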
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.word_embeddings.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.q_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.k_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.v_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(weights,
                                self.dropout_ratio,
                                training=self.training,
                                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.out_proj.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

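        # Same column/row weight split as in MLPLayer for the feed-forward part.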
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(self.linear0.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=[None, "mp"])
            auto.shard_tensor(self.linear1.weight,
                              process_mesh=_global_process_mesh,
                              shard_spec=["mp", None])

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(name="input_ids",
                                shape=[batch_size, sequence_len],
                                dtype='int64')
        position_ids = static.data(name="position_ids",
                                   shape=[batch_size, sequence_len],
                                   dtype='int64')
        decoder = DecoderLayer(vocab_size=32768,
                               hidden_size=hidden_size,
                               sequence_len=sequence_len,
                               max_position_embeddings=512,
                               intermediate_size=4 * hidden_size,
                               num_heads=16,
                               dropout_ratio=0.1,
                               initializer_range=0.02)
        out = decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):

    def test_decoder_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["dp"])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
                                                dim_names=["mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
                                                      [4, 5, 6, 7]],
                                                dim_names=["dp", "mp"])

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())


if __name__ == "__main__":
    unittest.main()