# test_auto_parallel_completion.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock

import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.fleet import auto

# Every program in this module is built in static-graph mode.
paddle.enable_static()

# Module-level switches: each test case assigns them before building a
# program, and the layers' forward() methods read them to decide which
# auto.shard_tensor annotations to add.
_global_parallel_strategy = None  # the tests use "dp", "mp" or "dp_mp"
_global_process_mesh = None
# In the visible code this is referenced only by the commented-out
# test_mlp_misc.
_global_process_mesh2 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm, Linear, GELU, Linear, then Dropout.

    When the module-level ``_global_parallel_strategy`` is ``"mp"`` or
    ``"dp_mp"``, ``forward`` annotates both linear weights with
    ``auto.shard_tensor`` on ``_global_process_mesh``.
    """

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the sub-layers.

        Args:
            hidden_size: Input/output width of the block and LayerNorm size.
            intermediate_size: Output width of ``linear0`` / input width of
                ``linear1``.
            dropout_ratio: Probability passed to the final ``nn.Dropout``.
            initializer_range: Standard deviation of the Normal initializer
                used for both linear weights.
        """
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Apply the block to ``input`` and return the result."""
        # Only the model-parallel strategies annotate the weights; the two
        # linears use opposite shard specs.
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)

        return out


def mlp_pretrain_forward(train_program, start_program):
    """Build the MLP forward pass inside ``train_program``.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Startup program paired with ``train_program``.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        data_input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )

        # The data-parallel strategies annotate the input on its first
        # dimension.
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                data_input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only for its effect of adding ops to the guarded program;
        # the output variable itself is not used.
        mlp(data_input)
    return train_program, start_program


class TestMLPAutoCompletion(unittest.TestCase):
    """Forward-annotation completion checks for the MLP program."""

    def _check_completion(self, strategy, mesh, dim_names):
        """Build the MLP program under ``strategy`` and validate completion.

        Args:
            strategy: Value stored in ``_global_parallel_strategy``.
            mesh: Process-id layout passed to ``auto.ProcessMesh``.
            dim_names: Mesh dimension names passed to ``auto.ProcessMesh``.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(
            train_program, start_program
        )
        # The program returned by the completer is not needed; the check
        # below goes through dist_context.
        Completer(dist_context).complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_mlp_dp(self):
        self._check_completion("dp", [0, 1, 2, 3], ["dp"])

    def test_mlp_mp(self):
        self._check_completion("mp", [0, 1, 2, 3], ["mp"])

    def test_mlp_dp_mp(self):
        self._check_completion(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )

    # def test_mlp_misc(self):
    #     # import pdb
    #     global _global_parallel_strategy
    #     _global_parallel_strategy = "pp"
    #     global _global_process_mesh
    #     _global_process_mesh = auto.ProcessMesh(
    #         mesh=[[0, 1], [2, 3]])
    #     global _global_process_mesh2
    #     _global_process_mesh2 = auto.ProcessMesh(
    #         mesh=[[4, 5], [6, 7]])

    #     train_program = static.Program()
    #     start_program = static.Program()
    #     dist_context = DistributedContext()
    #     train_program, start_program = mlp_pretrain_forward(train_program,
    #                                                         start_program)
    #     # pdb.set_trace()
    #     completer = Completer(dist_context)
    #     complete_train_program = auto.completer.complete_forward_annotation(train_program)
    #     # print_program_with_dist_attr(complete_train_program,
    #     #                                     dist_context)
    #     dist_context.finalize_distributed_attr_for_program(
    #         complete_train_program)
    #     from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
    #     for block in complete_train_program.blocks:
    #         for tensor in block.vars.values():
    #             desc = tensor.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             attr_name = append_distributed_attr_suffix("dims_mapping")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #         for op in block.ops:
    #             desc = op.desc
    #             attr_name = append_distributed_attr_suffix("mesh_id")
    #             self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.input_arg_names():
    #                 attr_name = append_distributed_attr_suffix("IN_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #             for tensor_name in desc.output_arg_names():
    #                 attr_name = append_distributed_attr_suffix("OUT_" +
    #                                                            tensor_name)
    #                 self.assertIsNotNone(desc.has_attr(attr_name))
    #     set_default_distributed_context(dist_context)
    #     self.assertTrue("dist_attr" in str(complete_train_program))
    #     with unittest.mock.patch(
    #             "sys.stdout", new_callable=StringIO) as mock_stdout:
    #         print_program_with_dist_attr(complete_train_program)
    #         self.assertIsNotNone(mock_stdout.getvalue())


class AttentionLayer(nn.Layer):
    """Multi-head self-attention with scaled dot-product scores.

    Query, key and value are all projected from the same input. Depending
    on the module-level ``_global_parallel_strategy``, ``forward`` annotates
    the input (``"dp"`` / ``"dp_mp"``) and the projection weights (``"mp"`` /
    ``"dp_mp"``) with ``auto.shard_tensor`` on ``_global_process_mesh``.
    """

    def __init__(
        self,
        hidden_size=1024,
        sequence_len=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create the four projection layers.

        Args:
            hidden_size: Embedding width; must be divisible by ``num_heads``.
            sequence_len: Stored on the instance; not used by ``forward``.
            intermediate_size: Accepted but not used by this layer.
            num_heads: Number of attention heads.
            dropout_ratio: Dropout probability applied to attention weights;
                a falsy value disables the dropout op.
            initializer_range: Standard deviation of the Normal initializer
                used for the projection weights.

        Raises:
            AssertionError: If ``hidden_size`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

    def forward(self, input):
        """Run self-attention over ``input`` and return the projected output."""
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None, None],
            )

        q = self.q_proj(input)
        # Split the last dimension into (num_heads, head_dim), then move the
        # head axis in front of the sequence axis.
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(input)
        v = self.v_proj(input)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)
        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        return out


def attn_pretrain_forward(train_program, start_program):
    """Build the attention forward pass inside ``train_program``.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Startup program paired with ``train_program``.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        query = static.data(
            name="query",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        attn = AttentionLayer(
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only for its effect of adding ops to the guarded program;
        # the output variable itself is not used.
        attn(query)

    return train_program, start_program


class TestAttentionAutoCompletion(unittest.TestCase):
    """Forward-annotation completion checks for the attention program."""

    def _check_completion(self, strategy, mesh, dim_names):
        """Build the attention program under ``strategy`` and validate it.

        Args:
            strategy: Value stored in ``_global_parallel_strategy``.
            mesh: Process-id layout passed to ``auto.ProcessMesh``.
            dim_names: Mesh dimension names passed to ``auto.ProcessMesh``.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = attn_pretrain_forward(
            train_program, start_program
        )
        # The program returned by the completer is not needed; the check
        # below goes through dist_context.
        Completer(dist_context).complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_attn_dp(self):
        self._check_completion("dp", [0, 1, 2, 3], ["dp"])

    def test_attn_mp(self):
        self._check_completion("mp", [0, 1, 2, 3], ["mp"])

    def test_attn_dp_mp(self):
        self._check_completion(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )


class DecoderLayer(nn.Layer):
    """Embedding lookup followed by a pre-norm attention + MLP decoder block.

    Depending on the module-level ``_global_parallel_strategy``, ``forward``
    annotates ``input_ids`` (``"dp"`` / ``"dp_mp"``) and the embedding,
    projection and feed-forward weights (``"mp"`` / ``"dp_mp"``) with
    ``auto.shard_tensor`` on ``_global_process_mesh``.
    """

    def __init__(
        self,
        vocab_size=32768,
        hidden_size=1024,
        sequence_len=512,
        max_position_embeddings=512,
        intermediate_size=4 * 1024,
        num_heads=16,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        """Create embeddings, attention projections, MLP, norms and dropouts.

        Args:
            vocab_size: Number of rows in the word-embedding table.
            hidden_size: Model width; must be divisible by ``num_heads``.
            sequence_len: Stored on the instance; not used by ``forward``.
            max_position_embeddings: Number of rows in the position table.
            intermediate_size: Accepted for signature compatibility but
                ignored; the feed-forward width is always ``4 * hidden_size``.
            num_heads: Number of attention heads.
            dropout_ratio: Probability used by all dropout ops.
            initializer_range: Standard deviation of the Normal initializer
                used for embedding and linear weights.

        Raises:
            AssertionError: If ``hidden_size`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.sequence_len = sequence_len
        self.embed_dim = self.hidden_size
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.num_heads = num_heads
        self.dropout_ratio = dropout_ratio
        self.initializer_range = initializer_range
        self.training = True
        self.attn_mask = None

        self.head_dim = self.embed_dim // self.num_heads
        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.word_embeddings = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="word_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )
        self.position_embeddings = nn.Embedding(
            self.max_position_embeddings,
            self.hidden_size,
            weight_attr=paddle.ParamAttr(
                name="pos_embeddings",
                initializer=nn.initializer.Normal(
                    mean=0.0, std=self.initializer_range
                ),
            ),
        )

        # Attention projections.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.q_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.k_proj = nn.Linear(
            self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.v_proj = nn.Linear(
            self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )
        self.out_proj = nn.Linear(
            self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr
        )

        # Feed-forward part. NOTE: this deliberately keeps the original
        # behavior of overriding the ``intermediate_size`` argument.
        intermediate_size = 4 * self.hidden_size
        d_model = self.hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(
                mean=0.0, std=self.initializer_range
            )
        )
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
        self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")

    def forward(self, input_ids, position_ids):
        """Embed the ids, then apply attention and MLP with residual adds.

        Args:
            input_ids: Token ids looked up in ``word_embeddings``.
            position_ids: Position ids looked up in ``position_embeddings``.

        Returns:
            The output of the second residual addition.
        """
        if _global_parallel_strategy in ["dp", "dp_mp"]:
            auto.shard_tensor(
                input_ids,
                process_mesh=_global_process_mesh,
                shard_spec=["dp", None],
            )

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.word_embeddings.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout1(embeddings)

        # Pre-norm
        target = self.norm1(embeddings)

        # The following is the attention part
        q = self.q_proj(target)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        k = self.k_proj(target)
        v = self.v_proj(target)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.q_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.k_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.v_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )

        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])

        # scale dot product attention
        product = tensor.matmul(x=q, y=k, transpose_y=True)
        product = tensor.scale(product, scale=self.head_dim**-0.5)

        if self.attn_mask is not None:
            product = product + self.attn_mask

        weights = F.softmax(product)

        if self.dropout_ratio:
            weights = F.dropout(
                weights,
                self.dropout_ratio,
                training=self.training,
                mode="upscale_in_train",
            )

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.out_proj.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        residual = embeddings + self.dropout2(out)

        # Pre-norm
        out0 = self.norm2(residual)

        # The following is the MLP part
        out1 = self.linear0(out0)
        out2 = F.gelu(out1, approximate=True)
        out3 = self.linear1(out2)

        if _global_parallel_strategy in ["mp", "dp_mp"]:
            auto.shard_tensor(
                self.linear0.weight,
                process_mesh=_global_process_mesh,
                shard_spec=[None, "mp"],
            )
            auto.shard_tensor(
                self.linear1.weight,
                process_mesh=_global_process_mesh,
                shard_spec=["mp", None],
            )

        # Add residual
        final = residual + self.dropout3(out3)
        return final


def decoder_pretrain_forward(train_program, start_program):
    """Build the decoder-layer forward pass inside ``train_program``.

    Args:
        train_program: Program that receives the forward ops.
        start_program: Startup program paired with ``train_program``.

    Returns:
        The same ``(train_program, start_program)`` pair, now populated.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input_ids = static.data(
            name="input_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        decoder = DecoderLayer(
            vocab_size=32768,
            hidden_size=hidden_size,
            sequence_len=sequence_len,
            max_position_embeddings=512,
            intermediate_size=4 * hidden_size,
            num_heads=16,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        # Called only for its effect of adding ops to the guarded program;
        # the output variable itself is not used.
        decoder(input_ids, position_ids)

    return train_program, start_program


class TestDecoderLayerAutoCompletion(unittest.TestCase):
    """Forward-annotation completion checks for the decoder-layer program."""

    def _check_completion(self, strategy, mesh, dim_names):
        """Build the decoder program under ``strategy`` and validate it.

        Args:
            strategy: Value stored in ``_global_parallel_strategy``.
            mesh: Process-id layout passed to ``auto.ProcessMesh``.
            dim_names: Mesh dimension names passed to ``auto.ProcessMesh``.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = strategy
        _global_process_mesh = auto.ProcessMesh(mesh=mesh, dim_names=dim_names)

        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program
        )
        # The program returned by the completer is not needed; the check
        # below goes through dist_context.
        Completer(dist_context).complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

    def test_decoder_dp(self):
        self._check_completion("dp", [0, 1, 2, 3], ["dp"])

    def test_decoder_mp(self):
        self._check_completion("mp", [0, 1, 2, 3], ["mp"])

    def test_decoder_dp_mp(self):
        self._check_completion(
            "dp_mp", [[0, 1, 2, 3], [4, 5, 6, 7]], ["dp", "mp"]
        )


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()