# test_auto_parallel_completion.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
    DistributedContext,
)
from paddle.distributed.fleet import auto

# Every test in this module builds static-graph programs, so static mode is
# enabled once at import time.
paddle.enable_static()

# Module-level switches read by the layers' forward() methods and by the
# *_pretrain_forward builders. Each test assigns them before building its
# program.
_global_parallel_strategy = None
_global_process_mesh = None
# In the visible source this is only referenced by the commented-out
# test_mlp_misc.
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear -> Dropout.

    When the module-level ``_global_parallel_strategy`` is ``"mp"`` or
    ``"dp_mp"``, ``forward`` annotates both linear weights with
    ``auto.shard_tensor`` on ``_global_process_mesh``.
    """

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the sub-layers.

        Args:
            hidden_size: Input/output width of the block.
            intermediate_size: Width between the two linear layers.
            dropout_ratio: Probability passed to the trailing ``nn.Dropout``.
            initializer_range: Standard deviation of the normal weight
                initializer.
        """
        super().__init__()
        # A single ParamAttr instance is passed to both linear layers, exactly
        # as in the original code.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )

        # Sub-layers are created in a fixed order: linear0, linear1, norm,
        # dropout.
        self.linear0 = nn.Linear(
            hidden_size, intermediate_size, weight_attr, bias_attr=None
        )
        self.linear1 = nn.Linear(
            intermediate_size, hidden_size, weight_attr, bias_attr=None
        )
        self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Apply the block to ``input`` and return the result.

        The sharding annotations are issued before any op is added, and only
        for the ``"mp"`` and ``"dp_mp"`` strategies.
        """
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # linear0.weight is annotated on its second axis, linear1.weight
            # on its first axis.
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build the ``MLPLayer`` forward graph inside the given programs.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Program passed as the second argument of
            ``static.program_guard``.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512

        # The variable keeps the name "input"; only the Python local is
        # renamed so it no longer shadows the builtin.
        data = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )

        if _global_parallel_strategy in ["dp", "dp_mp"]:
            # Annotate the first axis of the input with the "dp" mesh dim.
            auto.shard_tensor(
                data,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only to add the ops to train_program; the output is unused.
        mlp(data)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    """Forward-annotation completion on the MLP program for dp, mp and dp_mp."""

    def _complete_and_validate(self, strategy, mesh, dim_names):
        """Build the MLP program under ``strategy`` and validate completion.

        Sets the module-level strategy and process mesh (the layers read them
        while the program is built), runs
        ``Completer.complete_forward_annotation`` on the train program and
        asserts that ``dist_context.validate_dist_attr_for_program()`` is true.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        # The returned program is not needed; validation goes through
        # dist_context.
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3], ["dp"])

    def test_mlp_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3], ["mp"])

    def test_mlp_dp_mp(self):
        self._complete_and_validate(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )

    # NOTE: disabled test kept for reference. It uses names
    # (append_distributed_attr_suffix, set_default_distributed_context,
    # print_program_with_dist_attr, StringIO) that are not imported in the
    # visible part of this file.
    #
    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = auto.completer.complete_forward_annotation(train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.static.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    """Multi-head self-attention with q/k/v/out linear projections.

    ``forward`` reads the module-level ``_global_parallel_strategy`` and
    ``_global_process_mesh`` to decide which ``auto.shard_tensor``
    annotations to issue.
    """

    def __init__(
        self,
        hidden_size=1024,
        sequence_len=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the four projection layers.

        Args:
            hidden_size: Embedding width; must be divisible by ``num_heads``.
            sequence_len: Stored on the instance; not otherwise used here.
            intermediate_size: Accepted for signature parity with the other
                layers in this file; unused by this class.
            num_heads: Number of attention heads.
            dropout_ratio: Dropout probability applied to attention weights.
            initializer_range: Standard deviation of the normal weight
                initializer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        # One ParamAttr instance is shared by all four projections.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

    def forward(self, input):
        """Run self-attention over ``input`` and return the projected output.

        The statement order below is kept exactly as in the original so the
        ops are appended to the program in the same sequence.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        # Query: project, split into heads, then swap axes 1 and 2.
        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # q/k/v projection weights are all annotated on their second axis.
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # The output projection weight is annotated on its first axis.
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build the ``AttentionLayer`` forward graph inside the given programs.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Program passed as the second argument of
            ``static.program_guard``.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512

        # The variable is named "query"; the Python local avoids shadowing
        # the builtin `input`.
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        attn = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only to add the ops to train_program; the output is unused.
        attn(query)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    """Forward-annotation completion on the attention program for dp, mp and dp_mp."""

    def _complete_and_validate(self, strategy, mesh, dim_names):
        """Build the attention program under ``strategy`` and validate it.

        Sets the module-level strategy and process mesh (the layer reads them
        while the program is built), runs
        ``Completer.complete_forward_annotation`` on the train program and
        asserts that ``dist_context.validate_dist_attr_for_program()`` is true.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        # The returned program is not needed; validation goes through
        # dist_context.
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3], ["dp"])

    def test_attn_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3], ["mp"])

    def test_attn_dp_mp(self):
        self._complete_and_validate(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )


class DecoderLayer(nn.Layer):
    """Embeddings + pre-norm self-attention + pre-norm MLP, with residuals.

    ``forward`` reads the module-level ``_global_parallel_strategy`` and
    ``_global_process_mesh`` to decide which ``auto.shard_tensor``
    annotations to issue.
    """

    def __init__(
        self,
        vocab_size=32768,
        hidden_size=1024,
        sequence_len=512,
        max_position_embeddings=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create embeddings, attention projections, MLP, norms and dropouts.

        Args:
            vocab_size: Number of rows of the word embedding table.
            hidden_size: Model width; must be divisible by ``num_heads``.
            sequence_len: Stored on the instance; not otherwise used here.
            max_position_embeddings: Number of rows of the position
                embedding table.
            intermediate_size: Currently ignored; the MLP width is always
                ``4 * hidden_size`` (see the note in the body).
            num_heads: Number of attention heads.
            dropout_ratio: Probability used by every dropout in the layer.
            initializer_range: Standard deviation of the normal weight
                initializers.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Both embedding tables get explicitly named parameters.
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )

        # Attention projections share one ParamAttr instance.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

        # NOTE(review): the `intermediate_size` argument is overwritten here,
        # so the MLP width is always 4 * hidden_size. Kept as-is to preserve
        # behaviour; confirm whether the argument should be honoured.
        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size

        # The MLP linears get their own, freshly created ParamAttr instance.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        # dropout1 is built without an explicit mode, unlike dropout2/3.
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Run the decoder layer and return the final residual sum.

        The statement order below is kept exactly as in the original so the
        ops are appended to the program in the same sequence.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input_ids,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None],
            )

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # The word embedding table is annotated on its first axis.
            auto.shard_tensor(
                self.word_embeddings.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # q/k/v projection weights are all annotated on their second axis.
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # The output projection weight is annotated on its first axis.
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            # linear0.weight is annotated on its second axis, linear1.weight
            # on its first axis.
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build the ``DecoderLayer`` forward graph inside the given programs.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Program passed as the second argument of
            ``static.program_guard``.

    Returns:
        The same ``(train_program, start_program)`` pair that was passed in.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512

        # Both id inputs are int64 variables of shape
        # [batch_size, sequence_len].
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        decoder = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only to add the ops to train_program; the output is unused.
        decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    """Forward-annotation completion on the decoder program for dp, mp and dp_mp."""

    def _complete_and_validate(self, strategy, mesh, dim_names):
        """Build the decoder program under ``strategy`` and validate it.

        Sets the module-level strategy and process mesh (the layer reads them
        while the program is built), runs
        ``Completer.complete_forward_annotation`` on the train program and
        asserts that ``dist_context.validate_dist_attr_for_program()`` is true.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        # The returned program is not needed; validation goes through
        # dist_context.
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp(self):
        self._complete_and_validate("dp", [0, 1, 2, 3], ["dp"])

    def test_decoder_mp(self):
        self._complete_and_validate("mp", [0, 1, 2, 3], ["mp"])

    def test_decoder_dp_mp(self):
        self._complete_and_validate(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )


# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()