test_mnist_pure_fp16.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from time import time

import numpy as np
from test_mnist import MNIST, SEED, TestMNIST

import paddle

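# cuDNN selects convolution algorithms nondeterministically by default; force
# deterministic kernels so the dygraph/static loss comparison is reproducible.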
if paddle.fluid.is_compiled_with_cuda():
    paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})


class TestPureFP16(TestMNIST):
    def train_static(self):
        return self.train(to_static=True)

    def train_dygraph(self):
        return self.train(to_static=False)

    def test_mnist_to_static(self):
        if paddle.fluid.is_compiled_with_cuda():
            dygraph_loss = self.train_dygraph()
            static_loss = self.train_static()
            # NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
            np.testing.assert_allclose(
                dygraph_loss,
                static_loss,
                rtol=1e-05,
                atol=0.001,
                err_msg='dygraph_loss is {}\nstatic_loss is {}'.format(
                    dygraph_loss, static_loss
                ),
            )

    def train(self, to_static=False):
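        # Seed numpy, the dygraph generator, and static programs so both
        # training modes start from identical random states.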
        np.random.seed(SEED)
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

        mnist = MNIST()

        if to_static:
            print("Successfully to apply @to_static.")
58 59 60 61 62
            build_strategy = paddle.static.BuildStrategy()
            # Why set `build_strategy.enable_inplace = False` here?
            # Because this inplace pass of ParallelExecutor makes the dy2st
            # training loss unstable.
            build_strategy.enable_inplace = False
            mnist = paddle.jit.to_static(mnist, build_strategy=build_strategy)

        optimizer = paddle.optimizer.Adam(
            learning_rate=0.001, parameters=mnist.parameters()
        )

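        # Loss scaling guards fp16 gradients against underflow: the loss is
        # multiplied by the scale before backward and unscaled at the update.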
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

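        # Level 'O2' casts the model parameters to fp16 for pure-fp16 training;
        # save_dtype='float32' converts them back to fp32 when checkpoints are
        # saved.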
        mnist, optimizer = paddle.amp.decorate(
            models=mnist, optimizers=optimizer, level='O2', save_dtype='float32'
        )

        loss_data = []
        for epoch in range(self.epoch_num):
            start = time()
            for batch_id, data in enumerate(self.train_reader()):
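                # Batch the reader output: images become (N, 1, 28, 28) float32
                # arrays and labels (N, 1) int64 arrays.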
                dy_x_data = np.array(
                    [x[0].reshape(1, 28, 28) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape(-1, 1)
                )

                img = paddle.to_tensor(dy_x_data)
                label = paddle.to_tensor(y_data)
                label.stop_gradient = True

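                # Under 'O2' autocast the forward pass runs in fp16, except for
                # ops on the black list, which are kept in fp32.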
                with paddle.amp.auto_cast(
                    enable=True,
                    custom_white_list=None,
                    custom_black_list=None,
                    level='O2',
                ):
                    prediction, acc, avg_loss = mnist(img, label=label)

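                # Scale the loss, backprop the scaled gradients, then let the
                # scaler unscale them and step the optimizer (the step is
                # skipped and the scale shrunk if inf/nan gradients appear).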
                scaled = scaler.scale(avg_loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)

                loss_data.append(avg_loss.numpy()[0])
                # clear gradients for the next step
                mnist.clear_gradients()
                if batch_id % 2 == 0:
                    print(
                        "Loss at epoch {} step {}: loss: {}, acc: {}, cost: {}".format(
                            epoch,
                            batch_id,
                            avg_loss.numpy(),
                            acc.numpy(),
                            time() - start,
                        )
                    )
                    start = time()
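                # A few batches are enough for the loss comparison; stop early
                # to keep the test fast.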
                if batch_id == 10:
                    break
        return loss_data


if __name__ == '__main__':
    unittest.main()