# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.tensor as tensor
import paddle.utils as utils
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.fleet import auto

paddle.enable_static()
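# Globals set by each test case to select the parallel strategy and the
# process mesh that the layers below use when annotating sharding.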
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        if _global_parallel_strategy in ["mp", "dp_mp"]:
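            # Model parallelism: linear0 is column-parallel (weight split along
            # its output dimension) and linear1 is row-parallel (weight split
            # along its input dimension).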
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )

        if _global_parallel_strategy in ["dp", "dp_mp"]:
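            # Data parallelism: shard the batch dimension of the input across "dp".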
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        out = mlp(input)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    def test_mlp_dp(self):
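        # Build a partially annotated serial program, run the completer on it,
        # and validate the resulting distributed attributes.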
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["dp"]
        )
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = completer.complete_forward_annotation(train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=1024,
        sequence_len=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

    def forward(self, input):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        q = self.q_proj(input)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
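            # Column-parallel Q/K/V projections: each weight is split along its
            # output dimension across the "mp" mesh axis.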
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
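            # Row-parallel output projection: weight split along its input dimension.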
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        return out


def attn_pretrain_forward(train_program, start_program):
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        attn = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        out = attn(input)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    def test_attn_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["dp"]
        )
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())


class DecoderLayer(nn.Layer):
    def __init__(
        self,
        vocab_size=32768,
        hidden_size=1024,
        sequence_len=512,
        max_position_embeddings=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )

        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input_ids,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None],
            )

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
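            # Vocabulary-parallel embedding: the embedding table is split along
            # its vocabulary dimension across the "mp" mesh axis.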
            auto.shard_tensor(
                self.word_embeddings.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
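            # Same MLP sharding as MLPLayer: column-parallel linear0 and
            # row-parallel linear1.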
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        decoder = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        out = decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    def test_decoder_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["dp"]
        )
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[0, 1, 2, 3], dim_names=["mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["dp", "mp"]
        )

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program
        )
        self.assertTrue(dist_context.validate_dist_attr_for_program())


if __name__ == "__main__":
    unittest.main()