# test_auto_parallel_completion.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

# Every test below builds static-graph programs, so switch Paddle to static
# mode once at import time.
paddle.enable_static()

# Module-level configuration read by the layers' ``forward`` methods to decide
# which sharding annotations to attach. Each test assigns these before it
# builds its program.
# Parallel strategy name: None, "dp", "mp" or "dp_mp".
_global_parallel_strategy = None
# auto.ProcessMesh passed to every auto.shard_tensor annotation.
_global_process_mesh = None
# Second mesh; in this file it is only referenced by the commented-out
# test_mlp_misc test.
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
    """Pre-norm feed-forward block used to exercise auto-parallel completion.

    The forward pass is LayerNorm -> Linear -> GELU -> Linear -> Dropout.
    When the module-level ``_global_parallel_strategy`` is "mp" or "dp_mp",
    the first Linear weight is sharded along its second dimension and the
    second Linear weight along its first dimension, both over the "mp"
    dimension of ``_global_process_mesh``.
    """

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the sub-layers.

        Args:
            hidden_size: Input/output feature size of the block (``d_model``).
            intermediate_size: Output size of the first Linear layer.
            dropout_ratio: Dropout probability applied to the block output.
            initializer_range: Std-dev of the Normal initializer used for
                both Linear weights.
        """
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size

        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Apply the MLP to ``input`` and return the result.

        Sharding annotations (if any) are attached to the two Linear weights
        before the computation is built.
        """
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build one MLPLayer forward pass into ``train_program``.

    The input is a float32 ``static.data`` named "input" of shape
    [4, 512, 1024]. When ``_global_parallel_strategy`` is "dp" or "dp_mp",
    its first dimension is sharded over the "dp" mesh dimension.

    Args:
        train_program: static.Program used as the main program.
        start_program: static.Program used as the startup program.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_data = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )

        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input_data,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Only the ops recorded into the programs matter; the output is unused.
        mlp(input_data)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    """Forward-annotation completion tests for MLPLayer under dp/mp/dp_mp."""

    def _check_completion(self, strategy, mesh):
        """Build an MLP program under ``strategy``/``mesh``, complete it, validate.

        Args:
            strategy: Value assigned to the module-level
                ``_global_parallel_strategy`` ("dp", "mp" or "dp_mp").
            mesh: ``auto.ProcessMesh`` assigned to ``_global_process_mesh``.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp(self):
        self._check_completion(
            "dp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["dp"])
        )

    def test_mlp_mp(self):
        self._check_completion(
            "mp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["mp"])
        )

    def test_mlp_dp_mp(self):
        self._check_completion(
            "dp_mp",
            auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
            ),
        )

    # NOTE(review): disabled pipeline-parallel ("pp") test. It references
    # names that are not imported in this file (append_distributed_attr_suffix,
    # print_program_with_dist_attr, set_default_distributed_context, StringIO)
    # and would need updating before it is re-enabled.
    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = completer.complete_forward_annotation(train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    """Multi-head self-attention block used to exercise auto-parallel completion.

    Sharding is driven by the module-level ``_global_parallel_strategy``:
    - "dp" / "dp_mp": the input's first dimension is sharded over "dp".
    - "mp" / "dp_mp": the q/k/v projection weights are sharded along their
      second dimension and the output projection weight along its first
      dimension, over "mp".
    """

    def __init__(
        self,
        hidden_size=1024,
        sequence_len=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the q/k/v/out projections.

        Args:
            hidden_size: Embedding size; must be divisible by ``num_heads``.
            sequence_len: Stored on the layer; not used by ``forward``.
            intermediate_size: Accepted for interface parity; unused here.
            num_heads: Number of attention heads.
            dropout_ratio: Dropout probability on the attention weights
                (falsy disables dropout).
            initializer_range: Std-dev of the Normal weight initializer.

        Raises:
            AssertionError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

    def forward(self, input):
        """Apply scaled dot-product self-attention to ``input``.

        NOTE(review): assumes ``input`` is (batch, seq_len, embed_dim), as
        built by ``attn_pretrain_forward``; confirm for other callers.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        # A 0 in a reshape target keeps that input dimension unchanged; the
        # transpose then swaps the sequence and head axes.
        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention: q @ k^T scaled by 1/sqrt(head_dim)
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
        )

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads back into a single feature dimension
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build one AttentionLayer forward pass into ``train_program``.

    The input is a float32 ``static.data`` named "query" of shape
    [4, 512, 1024]; any input sharding is applied inside the layer.

    Args:
        train_program: static.Program used as the main program.
        start_program: static.Program used as the startup program.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        attn = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Only the ops recorded into the programs matter; the output is unused.
        attn(query)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    """Forward-annotation completion tests for AttentionLayer under dp/mp/dp_mp."""

    def _check_completion(self, strategy, mesh):
        """Build an attention program under ``strategy``/``mesh``, complete, validate.

        Args:
            strategy: Value assigned to the module-level
                ``_global_parallel_strategy`` ("dp", "mp" or "dp_mp").
            mesh: ``auto.ProcessMesh`` assigned to ``_global_process_mesh``.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp(self):
        self._check_completion(
            "dp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["dp"])
        )

    def test_attn_mp(self):
        self._check_completion(
            "mp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["mp"])
        )

    def test_attn_dp_mp(self):
        self._check_completion(
            "dp_mp",
            auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
            ),
        )


class DecoderLayer(nn.Layer):
    """Embedding + pre-norm attention + pre-norm MLP decoder block.

    Used to exercise auto-parallel completion on a GPT-like layer. Sharding is
    driven by the module-level ``_global_parallel_strategy``:
    - "dp" / "dp_mp": ``input_ids`` is sharded along its first dimension.
    - "mp" / "dp_mp": the word-embedding table is sharded along its first
      dimension; q/k/v and linear0 weights along their second dimension;
      out_proj and linear1 weights along their first dimension.
    """

    def __init__(
        self,
        vocab_size=32768,
        hidden_size=1024,
        sequence_len=512,
        max_position_embeddings=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the embeddings, attention projections, MLP and norms.

        Args:
            vocab_size: Rows of the word-embedding table.
            hidden_size: Embedding/model size; must be divisible by
                ``num_heads``.
            sequence_len: Stored on the layer; not used by ``forward``.
            max_position_embeddings: Rows of the position-embedding table.
            intermediate_size: Accepted for interface parity; the MLP width
                is always ``4 * hidden_size``.
            num_heads: Number of attention heads.
            dropout_ratio: Dropout probability for all dropout sites
                (falsy disables the attention-weight dropout).
            initializer_range: Std-dev of the Normal initializers.

        Raises:
            AssertionError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )

        # Attention projections.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

        # MLP; width is fixed at 4 * hidden_size regardless of the
        # ``intermediate_size`` argument.
        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Embed the ids, then run the attention and MLP sub-blocks.

        NOTE(review): assumes ``input_ids``/``position_ids`` are
        (batch, seq_len) int tensors, as built by ``decoder_pretrain_forward``.

        Returns:
            ``embeddings + attn_out + mlp_out`` (two residual connections).
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input_ids,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None],
            )

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.word_embeddings.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention: q @ k^T scaled by 1/sqrt(head_dim)
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5
        )

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads back into a single feature dimension
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build one DecoderLayer forward pass into ``train_program``.

    Inputs are two int64 ``static.data`` tensors, "input_ids" and
    "position_ids", each of shape [4, 512].

    Args:
        train_program: static.Program used as the main program.
        start_program: static.Program used as the startup program.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        decoder = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Only the ops recorded into the programs matter; the output is unused.
        decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    """Forward-annotation completion tests for DecoderLayer under dp/mp/dp_mp."""

    def _check_completion(self, strategy, mesh):
        """Build a decoder program under ``strategy``/``mesh``, complete, validate.

        Args:
            strategy: Value assigned to the module-level
                ``_global_parallel_strategy`` ("dp", "mp" or "dp_mp").
            mesh: ``auto.ProcessMesh`` assigned to ``_global_process_mesh``.
        """
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = mesh

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp(self):
        self._check_completion(
            "dp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["dp"])
        )

    def test_decoder_mp(self):
        self._check_completion(
            "mp", auto.ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["mp"])
        )

    def test_decoder_dp_mp(self):
        self._check_completion(
            "dp_mp",
            auto.ProcessMesh(
                mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
            ),
        )


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")