engine_api.py 15.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import time
17
import tempfile
18 19 20 21 22
import copy
import os
import numpy as np
import subprocess
import paddle
23 24
import paddle.static as static
import paddle.utils as utils
25 26 27 28 29 30 31
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
32

33
from paddle.distributed.fleet import auto
Z
zhaoyingli 已提交
34 35 36 37
from paddle.distributed.auto_parallel.interface import (
    get_collection,
    CollectionNames,
)
38 39
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
40 41

paddle.enable_static()

# Process meshes: one over ranks 0 and 1, and single-rank meshes that
# MLPLayer.forward assigns individual sublayers to via auto.shard_op.
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1])

epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10

paddle.seed(44)

# Flags read by MLPLayer.forward: `is_fetch` controls whether the output is
# registered with auto.fetch; `is_feed` controls whether intermediate vars
# are appended to `my_feed_vars`.
is_fetch = True
is_feed = True
# (var, shape) pairs appended by MLPLayer.forward; train_low_level turns
# them into a zero-filled feed dict.
my_feed_vars = []


class MyDataset(Dataset):
    """Map-style dataset of random ``(input, label)`` samples.

    Each sample is a float32 vector of length ``image_size`` drawn from
    U[0, 1) and an int64 scalar label drawn uniformly from ``[0, class_num)``.
    """

    def __init__(self, num_samples):
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        sample = np.random.uniform(size=image_size).astype("float32")
        # numpy's ``high`` bound is exclusive, so pass ``class_num`` (not
        # ``class_num - 1``) to make every class id reachable.
        label = np.random.randint(0, class_num, dtype="int64")
        return sample, label

    def __len__(self):
        return self.num_samples


def get_random_inputs_and_labels(image_shape, label_shape, num_classes=None):
    """Return a random ``(input, label)`` batch as numpy arrays.

    Args:
        image_shape: shape of the float32 input array, filled from U[0, 1).
        label_shape: shape of the int64 label array.
        num_classes: labels are drawn uniformly from ``[0, num_classes)``;
            defaults to the module-level ``class_num``.

    Returns:
        tuple: ``(input, label)`` with dtypes float32 and int64.
    """
    if num_classes is None:
        num_classes = class_num
    inputs = np.random.random(size=image_shape).astype('float32')
    # Draw integer class ids directly: casting U[0, 1) floats to int64
    # would truncate every label to 0.
    labels = np.random.randint(0, num_classes, size=label_shape, dtype='int64')
    return inputs, labels


def batch_generator_creator():
    """Return a reader function that yields ``batch_num`` random batches.

    Each yielded item is ``(input, label)`` built by
    ``get_random_inputs_and_labels`` with shapes ``[batch_size, image_size]``
    and ``[batch_size, 1]``.
    """

    def __reader__():
        for _ in range(batch_num):
            batch_input, batch_label = get_random_inputs_and_labels(
                [batch_size, image_size], [batch_size, 1]
            )
            yield batch_input, batch_label

    return __reader__


class MLPLayer(nn.Layer):
    """Feed-forward block with an output head of width 1.

    forward: LayerNorm (sharded onto PP_MESH_0) -> linear0 -> GELU ->
    linear1 (sharded onto PP_MESH_1) -> dropout -> linear2. Depending on the
    module-level flags, forward also appends intermediate vars to
    ``my_feed_vars`` and registers the output with ``auto.fetch``.
    """

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # One ParamAttr, Normal(0, initializer_range), shared by all linears.
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        out = auto.shard_op(self.norm, PP_MESH_0)(input)
        out = self.linear0(out)
        if is_feed:
            # Record the var together with its static shape so callers can
            # build a feed entry for it.
            my_feed_vars.append((out, out.shape))
        out = F.gelu(out, approximate=True)
        out = auto.shard_op(self.linear1, PP_MESH_1)(out)
        out = self.dropout(out)
        out = self.linear2(out)
        if is_feed:
            my_feed_vars.append((out, out.shape))
        if is_fetch:
            auto.fetch(out, "my_fetch", logging=True)
        return out


def train_high_level(fetch):
    """Exercise the high-level Engine API: fit, evaluate, predict, save/load.

    Args:
        fetch: stored into the module-level ``is_fetch`` flag, which controls
            whether MLPLayer.forward registers its output with auto.fetch.
    """
    global is_fetch
    is_fetch = fetch
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)

    # train, with a separate validation dataset
    train_dataset = MyDataset(batch_num * batch_size)
    eval_dataset1 = MyDataset(5 * batch_size)
    engine.fit(
        train_data=train_dataset,
        epochs=2,
        batch_size=batch_size,
        valid_data=eval_dataset1,
        log_freq=1,
    )

    # eval
    eval_dataset2 = MyDataset(batch_size)
    engine.evaluate(eval_dataset2, batch_size=batch_size)

    # predict
    test_dataset = MyDataset(batch_size)
    engine.predict(test_dataset, batch_size=batch_size)

    # save / load; the context manager removes the directory even if
    # save or load raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)


def train_low_level():
    """Exercise the low-level Engine API: dataloader / prepare / run per mode.

    The train -> eval -> predict -> save/load sequence runs twice, first
    with loaders from ``engine.dataloader`` and then with loaders from
    ``engine.dataloader_from_generator``. Every ``engine.run`` call also
    feeds zero arrays for the vars recorded in ``my_feed_vars``.
    """
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metrics=None, strategy=strategy)

    # Zero-filled feed for every (var, shape) recorded so far, keyed by name.
    feed_dict = {
        feed_var.name: np.zeros(shape, dtype="float32")
        for feed_var, shape in my_feed_vars
    }

    # ---- Loaders from engine.dataloader ----
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader(
        train_dataset, batch_size=batch_size, mode="train"
    )
    engine.prepare(mode="train")
    for data in train_dataloader:
        engine.run(data, feed=feed_dict, mode="train")

    # eval
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader(
        eval_dataset2, batch_size=batch_size, mode="eval"
    )
    engine.prepare(mode="eval")
    for data in eval_dataloader:
        engine.run(data, feed=feed_dict, mode="eval")

    # predict: the mode is set via to_mode() here, so the calls below omit mode=
    engine.to_mode("predict")
    test_dataset = MyDataset(batch_size)
    predict_dataloader = engine.dataloader(test_dataset, batch_size=batch_size)
    engine.prepare()
    for data in predict_dataloader:
        engine.run(data, feed=feed_dict)

    # save / load (directory is removed even if save or load raises)
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)

    # ---- Loaders from engine.dataloader_from_generator ----
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader_from_generator(
        train_dataset, batch_size=batch_size, mode="train"
    )
    engine.prepare(mode="train")
    for data in train_dataloader:
        engine.run(data, feed=feed_dict, mode="train")

    # eval: the mode is set via to_mode() here, so the calls below omit mode=
    engine.to_mode("eval")
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader_from_generator(
        eval_dataset2, batch_size=batch_size
    )
    engine.prepare()
    for data in eval_dataloader:
        engine.run(data, feed=feed_dict)

    # predict
    test_dataset = MyDataset(batch_size)
    predict_dataloader = engine.dataloader_from_generator(
        test_dataset, batch_size=batch_size, mode="predict"
    )
    engine.prepare(mode="predict")
    for data in predict_dataloader:
        engine.run(data, feed=feed_dict, mode="predict")

    # save / load
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp')
        engine.save(model_filename, training=True)
        engine.load(model_filename)


def train_builtin_data_vars():
    """Train using the data vars that ``engine.prepare`` builds from specs.

    ``engine.prepare`` receives input/label InputSpecs. The resulting
    ``engine.inputs`` + ``engine.labels`` become the feed list of a
    non-iterable generator DataLoader. ``engine.run()`` is then called with
    no explicit data until the loader signals EOFException.
    """
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)

    # train
    engine.to_mode("train")

    input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
    label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
    engine.prepare(inputs_spec=[input_spec], labels_spec=[label_spec])

    # Build the loader inside the engine's programs so it binds to their vars.
    with static.program_guard(engine.main_program, engine.startup_program):
        feed_list = engine.inputs + engine.labels
        print(feed_list)
        loader = paddle.io.DataLoader.from_generator(
            feed_list=feed_list, capacity=4 * batch_size, iterable=False
        )

        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

    for _ in range(epoch_num):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException


def train_non_builtin_data_vars():
    """Train an Engine prepared from user-built static programs and data vars.

    The programs, ``static.data`` vars, generator DataLoader and loss var are
    all built by hand. The Engine is then given ``loss_var`` plus the
    programs through ``engine.prepare``.
    """
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input_var = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label_var = static.data(
            name="label", shape=[batch_size, 1], dtype='int64'
        )

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input_var, label_var],
            capacity=4 * batch_size,
            iterable=False,
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input_var)
        loss_var = loss(predict, label_var)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )

    # train
    engine.to_mode("train")
    engine.prepare(
        inputs=[input_var],
        labels=[label_var],
        main_program=main_program,
        startup_program=startup_program,
    )
    for _ in range(epoch_num):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException


def get_cost():
    """Call ``engine.cost()`` on explicitly supplied static programs.

    Builds fresh main/startup programs holding the data vars, loader, MLP and
    loss var. Passes them to ``engine.prepare(mode="train")``, then calls
    ``engine.cost()``.
    """
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input_var = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label_var = static.data(
            name="label", shape=[batch_size, 1], dtype='int64'
        )

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input_var, label_var],
            capacity=4 * batch_size,
            iterable=False,
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input_var)
        loss_var = loss(predict, label_var)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )
    engine.prepare(
        main_program=main_program,
        startup_program=startup_program,
        inputs=[input_var],
        labels=[label_var],
        mode="train",
    )
    engine.cost()


def get_cost_by_default_program():
    """Call ``engine.cost(mode="train")`` after building into default programs.

    Unlike ``get_cost``, the graph is built into
    ``static.default_main_program()`` / ``default_startup_program()`` and no
    explicit ``engine.prepare`` is issued. NOTE(review): presumably the
    Engine picks up the default programs here; confirm against Engine.cost.
    """
    main_program = static.default_main_program()
    startup_program = static.default_startup_program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input_var = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label_var = static.data(
            name="label", shape=[batch_size, 1], dtype='int64'
        )

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input_var, label_var],
            capacity=4 * batch_size,
            iterable=False,
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input_var)
        loss_var = loss(predict, label_var)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )
    engine.cost(mode="train")


def get_cost_by_spec():
    """Call ``engine.cost`` in eval mode, describing the data only by specs.

    The Engine is built from the ``MLPLayer`` layer object. ``engine.cost``
    is then called with input/label ``InputSpec`` objects and no prior
    ``engine.prepare``.
    """
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)

    input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
    label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
    engine.cost(mode="eval", inputs_spec=[input_spec], labels_spec=[label_spec])


if __name__ == "__main__":
    # Run every scenario in sequence; note that MLPLayer.forward appends to
    # the module-level my_feed_vars list on each build.
    train_high_level(fetch=True)
    train_high_level(fetch=False)
    train_low_level()
    train_builtin_data_vars()
    train_non_builtin_data_vars()
    get_cost()
    get_cost_by_default_program()
    get_cost_by_spec()