# engine_api.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import os
import numpy as np
import paddle
import paddle.static as static
import paddle.utils as utils
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset

from paddle.distributed.fleet import auto

# Every scenario in this file builds static-graph programs.
paddle.enable_static()

# Process meshes: ranks [0, 1] together, plus one single-rank mesh per
# pipeline stage. MLPLayer.forward places its LayerNorm on PP_MESH_0 and its
# second Linear on PP_MESH_1 via auto.shard_op.
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1])

epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10

paddle.seed(44)

# Module-level switches read by MLPLayer.forward:
#   is_fetch -- when True, the layer output is registered with
#               auto.fetch(..., "my_out", logging=True).
#   is_feed  -- when True, (variable, shape) pairs are appended to
#               my_feed_vars; train_low_level turns them into a feed dict.
is_fetch = True
is_feed = True
my_feed_vars = []


class MyDataset(Dataset):
    """Map-style dataset that returns freshly drawn random samples.

    Each item is ``(sample, label)``:
      * ``sample``: float32 vector of length ``image_size`` drawn from
        ``np.random.uniform`` (i.e. in [0, 1)).
      * ``label``: int64 scalar drawn from
        ``np.random.randint(0, class_num - 1)``.

    Items are generated on every access and ``index`` is ignored, so repeated
    reads of the same index differ unless NumPy's global RNG is reseeded.

    Args:
        num_samples (int): value reported by ``__len__``.
    """

    def __init__(self, num_samples):
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        sample = np.random.uniform(size=image_size).astype("float32")
        # NOTE(review): randint's upper bound is exclusive, so the label is
        # in [0, class_num - 2]. MLPLayer also produces a single output unit.
        # Confirm the intended label range before changing this.
        label = np.random.randint(0, class_num - 1, dtype="int64")
        return sample, label

    def __len__(self):
        return self.num_samples

def get_random_inputs_and_labels(image_shape, label_shape):
    """Draw one random batch of features and labels.

    Features are float32 values from ``np.random.random`` (range [0, 1)).
    Labels come from the same generator and are then cast to int64. The
    cast truncates toward zero, so every label is 0. This is consistent with
    MLPLayer's single output unit.

    Args:
        image_shape: shape of the feature array.
        label_shape: shape of the label array.

    Returns:
        tuple: ``(features float32 ndarray, labels int64 ndarray)``.
    """
    # The draw order (features first, then labels) determines RNG consumption.
    raw_features = np.random.random(size=image_shape)
    raw_labels = np.random.random(size=label_shape)
    return raw_features.astype('float32'), raw_labels.astype('int64')


def batch_generator_creator():
    """Build a zero-argument batch generator for DataLoader.set_batch_generator.

    Each call of the returned function yields ``batch_num`` tuples of
    ``(inputs, labels)``. ``inputs`` has shape ``[batch_size, image_size]``
    and ``labels`` has shape ``[batch_size, 1]``, as produced by
    ``get_random_inputs_and_labels``.
    """

    def __reader__():
        for _ in range(batch_num):
            yield get_random_inputs_and_labels([batch_size, image_size],
                                               [batch_size, 1])

    return __reader__

class MLPLayer(nn.Layer):
    """Feed-forward block whose stages are annotated onto two process meshes.

    Forward path:
        LayerNorm (shard_op on PP_MESH_0)
        -> Linear(hidden_size -> intermediate_size) -> GELU
        -> Linear(intermediate_size -> hidden_size) (shard_op on PP_MESH_1)
        -> Dropout -> Linear(hidden_size -> 1)

    Args:
        hidden_size (int): model width (input/output size of the block).
        intermediate_size (int): width of the inner projection.
        dropout_ratio (float): dropout probability ("upscale_in_train" mode).
        initializer_range (float): std of the Normal(0, std) weight initializer.
    """

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size

        # The same ParamAttr (Normal(0, initializer_range) initializer) is
        # passed to all three Linear layers.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Run the block on ``input``.

        ``input`` is presumably ``[batch, hidden_size]`` (LayerNorm is built
        over ``hidden_size``). TODO: confirm against the callers.

        Side effects, controlled by module-level flags:
          * ``is_feed``: appends ``(var, var.shape)`` for the linear0 output
            and for the final output to the global ``my_feed_vars`` list.
          * ``is_fetch``: registers the final output via
            ``auto.fetch(out, "my_out", logging=True)``.
        """
        out = auto.shard_op(self.norm, PP_MESH_0)(input)
        out = self.linear0(out)
        if is_feed:
            my_feed_vars.append((out, out.shape))
        out = F.gelu(out, approximate=True)
        out = auto.shard_op(self.linear1, PP_MESH_1)(out)
        out = self.dropout(out)
        out = self.linear2(out)
        if is_feed:
            my_feed_vars.append((out, out.shape))
        if is_fetch:
            auto.fetch(out, "my_out", logging=True)
        return out

def train_high_level(fetch):
    """Exercise the high-level Engine API: fit, evaluate, predict, save, load.

    Args:
        fetch (bool): stored in the module-level ``is_fetch`` flag. MLPLayer
            reads that flag to decide whether to register its output via
            ``auto.fetch``. The flag is not restored afterwards.
    """
    global is_fetch
    is_fetch = fetch

    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   dropout_ratio=0.1,
                   initializer_range=0.02)
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
                                      beta1=0.9,
                                      beta2=0.999,
                                      epsilon=1e-08,
                                      grad_clip=None)
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)

    # train (with a validation set passed to fit)
    train_dataset = MyDataset(batch_num * batch_size)
    eval_dataset1 = MyDataset(5 * batch_size)
    engine.fit(train_data=train_dataset,
               epochs=2,
               batch_size=batch_size,
               valid_data=eval_dataset1)

    # eval
    eval_dataset2 = MyDataset(batch_size)
    engine.evaluate(eval_dataset2, batch_size=batch_size)

    # predict
    test_dataset = MyDataset(batch_size)
    engine.predict(test_dataset, batch_size=batch_size)

    # save / load round trip. The context manager removes the temporary
    # directory even if save or load raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)


def train_low_level():
    """Exercise the low-level Engine API: dataloader + prepare + run.

    The same Engine is driven twice: first through ``engine.dataloader`` and
    then through ``engine.dataloader_from_generator``. Each pass covers
    train, eval and predict, followed by a save/load round trip. Every
    ``engine.run`` call also receives an explicit ``feed`` dict built from
    the module-level ``my_feed_vars`` list.
    """
    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   dropout_ratio=0.1,
                   initializer_range=0.02)
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
                                      beta1=0.9,
                                      beta2=0.999,
                                      epsilon=1e-08,
                                      grad_clip=None)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metrics=None, strategy=strategy)

    # Zero-filled float32 values keyed by the name of every variable recorded
    # in my_feed_vars. NOTE(review): the list is module-global and is only
    # appended to, so at this point it may hold variables recorded by earlier
    # scenarios. Confirm that feeding them is intended.
    feed_dict = {
        feed_var.name: np.zeros(shape, dtype="float32")
        for feed_var, shape in my_feed_vars
    }

    # ---- Build a normal dataloader ----
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader(train_dataset,
                                         batch_size=batch_size,
                                         mode="train")
    engine.prepare(mode="train")
    for data in train_dataloader:
        engine.run(data, feed=feed_dict, mode="train")

    # eval
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader(eval_dataset2,
                                        batch_size=batch_size,
                                        mode="eval")
    engine.prepare(mode="eval")
    for data in eval_dataloader:
        engine.run(data, feed=feed_dict, mode="eval")

    # predict: the mode is set once with to_mode() and omitted from the
    # dataloader/prepare/run calls that follow.
    engine.to_mode("predict")
    test_dataset = MyDataset(batch_size)
    predict_dataloader = engine.dataloader(test_dataset, batch_size=batch_size)
    engine.prepare()
    for data in predict_dataloader:
        engine.run(data, feed=feed_dict)

    # save / load round trip. The directory is removed even if either raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)

    # ---- Build dataloader from generator ----
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader_from_generator(train_dataset,
                                                        batch_size=batch_size,
                                                        mode="train")
    engine.prepare(mode="train")
    for data in train_dataloader:
        engine.run(data, feed=feed_dict, mode="train")

    # eval: the mode is set with to_mode() rather than passed per call.
    engine.to_mode("eval")
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader_from_generator(eval_dataset2,
                                                       batch_size=batch_size)
    engine.prepare()
    for data in eval_dataloader:
        engine.run(data, feed=feed_dict)

    # predict
    test_dataset = MyDataset(batch_size)
    predict_dataloader = engine.dataloader_from_generator(test_dataset,
                                                          batch_size=batch_size,
                                                          mode="predict")
    engine.prepare(mode="predict")
    for data in predict_dataloader:
        engine.run(data, feed=feed_dict, mode="predict")

    # save / load round trip. The directory is removed even if either raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)


def train_builtin_data_vars():
    """Train with Engine-created data variables and a non-iterable DataLoader.

    The Engine builds its programs from InputSpecs. A
    ``DataLoader.from_generator`` bound to ``engine.inputs + engine.labels``
    (created under the Engine's main/startup programs) is then fed by
    ``batch_generator_creator()``. Each of the ``epoch_num`` epochs calls
    ``engine.run()`` repeatedly until the loader raises ``EOFException``.
    """
    model = MLPLayer(hidden_size=hidden_size,
                     intermediate_size=4 * hidden_size,
                     dropout_ratio=0.1,
                     initializer_range=0.02)
    criterion = paddle.nn.CrossEntropyLoss()
    adam = paddle.optimizer.Adam(learning_rate=0.00001,
                                 beta1=0.9,
                                 beta2=0.999,
                                 epsilon=1e-08,
                                 grad_clip=None)
    accuracy = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(model, criterion, adam, accuracy, strategy=strategy)

    # train
    engine.to_mode("train")

    feature_spec = static.InputSpec([batch_size, image_size], 'float32',
                                    'input')
    target_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
    engine.prepare(inputs_spec=[feature_spec], labels_spec=[target_spec])

    with static.program_guard(engine.main_program, engine.startup_program):
        data_vars = engine.inputs + engine.labels
        print(data_vars)
        loader = paddle.io.DataLoader.from_generator(feed_list=data_vars,
                                                     capacity=4 * batch_size,
                                                     iterable=False)
        device_places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(),
                                   places=device_places)

    for _ in range(epoch_num):
        # A non-iterable loader must be started at the beginning of each epoch.
        loader.start()
        try:
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            # Once the loader reports exhaustion, reset it for the next epoch.
            loader.reset()


def train_non_builtin_data_vars():
    """Train an Engine on a user-built static program and user data variables.

    The input/label variables, DataLoader, model, loss and optimizer are all
    created inside a caller-owned Program pair (under a fresh unique-name
    scope). The resulting loss variable and programs are then handed to
    ``auto.Engine``/``engine.prepare``. Each of the ``epoch_num`` epochs runs
    until the loader raises ``EOFException``.
    """
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(main_program, startup_program):
        with utils.unique_name.guard():
            input_var = static.data(name="input",
                                    shape=[batch_size, image_size],
                                    dtype='float32')
            label_var = static.data(name="label",
                                    shape=[batch_size, 1],
                                    dtype='int64')

            loader = paddle.io.DataLoader.from_generator(
                feed_list=[input_var, label_var],
                capacity=4 * batch_size,
                iterable=False)
            device_places = static.cuda_places()
            loader.set_batch_generator(batch_generator_creator(),
                                       places=device_places)

            model = MLPLayer(hidden_size=hidden_size,
                             intermediate_size=4 * hidden_size,
                             dropout_ratio=0.1,
                             initializer_range=0.02)
            criterion = paddle.nn.CrossEntropyLoss()
            adam = paddle.optimizer.Adam(learning_rate=0.00001,
                                         beta1=0.9,
                                         beta2=0.999,
                                         epsilon=1e-08,
                                         grad_clip=None)
            accuracy = paddle.metric.Accuracy()
            loss_var = criterion(model(input_var), label_var)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(loss=loss_var,
                         optimizer=adam,
                         metrics=accuracy,
                         strategy=strategy)

    # train
    engine.to_mode("train")
    engine.prepare(inputs=[input_var],
                   labels=[label_var],
                   main_program=main_program,
                   startup_program=startup_program)
    for _ in range(epoch_num):
        # A non-iterable loader must be started at the beginning of each epoch.
        loader.start()
        try:
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            # Once the loader reports exhaustion, reset it for the next epoch.
            loader.reset()

if __name__ == "__main__":
    # Run every scenario in sequence. train_low_level relies on my_feed_vars,
    # which MLPLayer.forward fills during the earlier runs.
    train_high_level(fetch=True)
    train_high_level(fetch=False)
    train_low_level()
    train_builtin_data_vars()
    train_non_builtin_data_vars()