# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import time
import tempfile
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
import paddle.distributed.auto_parallel as auto
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn

# The auto-parallel Engine below is driven in static-graph mode.
paddle.enable_static()

# A two-rank mesh plus one single-rank mesh per stage (the PP_ prefix
# presumably means pipeline parallel). MLPLayer.forward places its LayerNorm
# on PP_MESH_0 and its linear1 on PP_MESH_1. global_process_mesh is not
# referenced by the code visible in this file.
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1])

# Sizes shared by MyDataset, MLPLayer and train().
batch_size = 1
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size  # length of each synthetic feature vector
class_num = 10

# Fix Paddle's global random seed. MyDataset draws from NumPy's RNG, which is
# not seeded here.
paddle.seed(44)

# Module-level switch read by MLPLayer.forward at call time; train() rebinds
# it through `global is_fetch` before building the model.
is_fetch = True

class MyDataset(Dataset):
    """Map-style dataset yielding random ``(features, label)`` pairs.

    Items are generated on every access from NumPy's global RNG, so two
    reads of the same index return different data.
    """

    def __init__(self, num_samples):
        """
        Args:
            num_samples: number of items reported by ``__len__``.
        """
        super(MyDataset, self).__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """Return a fresh random sample; ``index`` is not used.

        Returns:
            A tuple of a float32 array of shape ``(image_size,)`` drawn from
            U[0, 1) and an int64 scalar label.
        """
        features = np.random.uniform(size=image_size).astype("float32")
        # NOTE(review): numpy's randint excludes `high`, so labels lie in
        # [0, class_num - 2] and class_num - 1 is never produced. The value
        # is left unchanged because MLPLayer emits a single output per
        # sample; confirm the intended label range before changing it.
        label = np.random.randint(0, class_num - 1, dtype="int64")
        return features, label

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples


class MLPLayer(nn.Layer):
    """Small MLP used to exercise semi-automatic parallel annotations.

    Forward path: LayerNorm -> linear0 -> GELU -> linear1 -> dropout ->
    linear2, whose output has a last dimension of size 1. The LayerNorm call
    is annotated with ``PP_MESH_0`` and the ``linear1`` call with
    ``PP_MESH_1`` via ``auto.shard_op``. The other ops carry no explicit
    annotation here.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        """
        Args:
            hidden_size: width of the input, of the LayerNorm and of
                linear1's output.
            intermediate_size: width of linear0's output (feed-forward dim).
            dropout_ratio: drop probability of the dropout layer.
            initializer_range: std of the Normal(0, std) weight initializer.
        """
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # A single ParamAttr object is passed to all three Linear layers, so
        # they share the same initializer configuration.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Run the MLP on ``input`` and return the output tensor.

        ``input`` is presumably ``(batch, hidden_size)`` float32, matching
        batched MyDataset features — TODO confirm. When the module-level
        ``is_fetch`` flag is true, the output is also registered via
        ``auto.fetch`` under the name ``"out"``.
        """
        out = auto.shard_op(self.norm, PP_MESH_0)(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = auto.shard_op(self.linear1, PP_MESH_1)(out)
        out = self.dropout(out)
        out = self.linear2(out)
        # The global is read at call time, so train(fetch=...) controls it.
        if is_fetch:
            auto.fetch(out, "out")
        return out


def train(fetch):
    """Drive fit, evaluate, predict, save and load through ``auto.Engine``.

    Args:
        fetch: assigned to the module-level ``is_fetch`` flag before the
            model is built. When true, ``MLPLayer.forward`` registers its
            output via ``auto.fetch`` under the name ``"out"``.
    """
    global is_fetch
    is_fetch = fetch

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   dropout_ratio=0.1,
                   initializer_range=0.02)
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
                                      beta1=0.9,
                                      beta2=0.999,
                                      epsilon=1e-08,
                                      grad_clip=None)
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)

    # train, passing a separate dataset as validation data
    train_dataset = MyDataset(batch_num * batch_size)
    eval_dataset1 = MyDataset(5 * batch_size)
    engine.fit(train_data=train_dataset,
               epochs=2,
               batch_size=batch_size,
               valid_data=eval_dataset1)

    # eval
    eval_dataset2 = MyDataset(batch_size)
    engine.evaluate(eval_dataset2, batch_size=batch_size)

    # predict
    test_dataset = MyDataset(batch_size)
    engine.predict(test_dataset, batch_size=batch_size)

    # save / load round trip. The context manager deletes the temporary
    # directory even when save() or load() raises, so a failure here no
    # longer leaks it (explicit cleanup() only ran on success).
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)


if __name__ == "__main__":
    # Run once with and once without auto.fetch in MLPLayer.forward.
    train(fetch=True)
    train(fetch=False)