# test_rule_based_tuner_o2.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np

import paddle
from paddle import static

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTForPretraining,
    GPTModel,
    GPTPretrainingCriterion,
)


def get_gpt_model(
    train_program, start_program, place, batch_size, sequence_len, vocab_size
):
    """Build a small GPT pretraining network inside the given static programs.

    Args:
        train_program: Main ``static.Program`` passed to ``program_guard``;
            the network and loss ops are recorded into it.
        start_program: Startup ``static.Program`` passed to ``program_guard``.
        place: Unused by this function; kept for signature compatibility.
        batch_size: Leading dimension of every input placeholder.
        sequence_len: Sequence dimension of every input placeholder.
        vocab_size: Vocabulary size used for the model's ``vocab_size`` and
            as the exclusive upper bound of the random ids from ``gen_data``.

    Returns:
        Tuple ``(train_program, start_program, loss, gen_data)``, where
        ``loss`` is the output of ``GPTPretrainingCriterion`` and
        ``gen_data`` is a zero-argument callable returning random feed data
        ``(tokens, position_ids, attention_mask, labels, loss_mask)``.
    """
    with static.program_guard(train_program, start_program):
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = paddle.static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype='float32',
        )
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype='int64'
        )
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
        )

        # Use the caller-supplied vocab_size so the embedding/output layers
        # agree with the token-id range produced by gen_data below.
        gpt = GPTModel(
            vocab_size=vocab_size,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3,
        )

        model = GPTForPretraining(
            gpt, vocab_size=vocab_size, hidden_size=64, initializer_range=0.02
        )
        preds = model(tokens, position_ids, attention_mask)
        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)

    def gen_data():
        """Return deterministic random feed data (global numpy seed 2021).

        Each returned value is a list of ``batch_size`` per-example numpy
        arrays, cast to the dtype declared by the matching placeholder above
        (``int64`` ids, ``float32`` masks).
        """
        np.random.seed(2021)
        tokens = []
        position_ids = []
        attention_mask = []
        labels = []
        loss_mask = []
        for _ in range(batch_size):
            # astype (rather than randint's dtype=) keeps the random stream
            # identical while guaranteeing int64 on every platform.
            tokens.append(
                np.random.randint(vocab_size, size=sequence_len).astype('int64')
            )
            position_ids.append(np.arange(sequence_len, dtype='int64'))
            # np.tril on a 1-D vector broadcasts it to a
            # (sequence_len, sequence_len) lower-triangular matrix; the extra
            # list level supplies the singleton axis at position 1 of the
            # placeholder shape.
            attention_mask.append(
                [np.tril(np.ones(sequence_len, dtype='float32'))]
            )
            labels.append(
                np.random.randint(vocab_size, size=sequence_len).astype('int64')
            )
            loss_mask.append(np.ones(sequence_len, dtype='float32'))

        return tokens, position_ids, attention_mask, labels, loss_mask

    return train_program, start_program, loss, gen_data


class TestRuleBasedTuner(unittest.TestCase):
    """Smoke test for the rule-based auto-parallel tuner at level "o2"."""

    def test_gpt_o2(self):
        """Build a small GPT program and run ``RuleBasedTuner(level="o2")``.

        The test passes if ``tuner.tune()`` completes without raising; it
        makes no assertions about the resulting parallel plan.
        """
        modeling.init_global()
        train_program = static.Program()
        start_program = static.Program()
        batch_size = 8
        sequence_len = 512
        vocab_size = 1000
        place = None
        train_program, start_program, loss, gen_data = get_gpt_model(
            train_program,
            start_program,
            place,
            batch_size,
            sequence_len,
            vocab_size,
        )

        from paddle.distributed.auto_parallel.static.cluster import Cluster
        from paddle.distributed.auto_parallel.static.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
            RuleBasedTuner,
        )

        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)

        # Default-config cluster description with 1 node and 8 devices.
        cluster = Cluster()
        cluster.gen_default_config_cluster(node_count=1, device_count=8)
        dist_context = DistributedContext(
            serial_main_prog=train_program,
            serial_startup_prog=start_program,
            serial_optimizer=opt,
            serial_loss=loss,
            cluster=cluster,
        )
        dist_context.initialize()
        tuner = RuleBasedTuner(dist_context, level="o2")
        tuner.tune()


if __name__ == "__main__":
    # Run the test suite when this file is executed directly rather than
    # collected by an external test runner.
    unittest.main()