# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from time import time

import numpy as np
from dygraph_to_static_util import test_and_compare_with_new_ir
from test_mnist import MNIST, SEED, TestMNIST

import paddle
from paddle.fluid.optimizer import AdamOptimizer

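# Force deterministic cuDNN kernels on GPU so the dygraph and static runs
# produce comparable losses.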
if paddle.fluid.is_compiled_with_cuda():
    paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})


class TestAMP(TestMNIST):
    def train_static(self):
        return self.train(to_static=True)

    def train_dygraph(self):
        return self.train(to_static=False)

    @test_and_compare_with_new_ir(False)
    def test_mnist_to_static(self):
        dygraph_loss = self.train_dygraph()
        static_loss = self.train_static()
        # NOTE(Aurelius84): Static-graph AMP training uses a grep_list while
        # dygraph AMP does not, so the number of inserted cast ops differs
        # and the losses show a small difference.
        np.testing.assert_allclose(
            dygraph_loss,
            static_loss,
            rtol=1e-05,
            atol=0.001,
            err_msg='dygraph_res is {}\nstatic_res is {}'.format(
                dygraph_loss, static_loss
            ),
        )

    def train(self, to_static=False):
        paddle.seed(SEED)
        mnist = MNIST()

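        # paddle.jit.to_static converts the dygraph model into a static-graph
        # program, which is the code path exercised by train_static().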
        if to_static:
            print("Successfully to apply @to_static.")
            mnist = paddle.jit.to_static(mnist)

        adam = AdamOptimizer(
            learning_rate=0.001, parameter_list=mnist.parameters()
        )

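        # GradScaler multiplies the loss by init_loss_scaling so small float16
        # gradients do not underflow during the backward pass.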
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        loss_data = []
        for epoch in range(self.epoch_num):
            start = time()
            for batch_id, data in enumerate(self.train_reader()):
                dy_x_data = np.array(
                    [x[0].reshape(1, 28, 28) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape(-1, 1)
                )

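                # Feed the numpy batch to the model as paddle tensors.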
                img = paddle.to_tensor(dy_x_data)
                label = paddle.to_tensor(y_data)
                label.stop_gradient = True

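                # auto_cast runs the forward pass in mixed precision: ops that
                # are numerically safe execute in float16, the rest stay float32.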
                with paddle.amp.auto_cast():
                    prediction, acc, avg_loss = mnist(img, label=label)

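                # Backpropagate the scaled loss; scaler.minimize() unscales the
                # gradients, skips the step on inf/nan, and applies the update.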
                scaled = scaler.scale(avg_loss)
                scaled.backward()
                scaler.minimize(adam, scaled)

                loss_data.append(float(avg_loss))
                # clear accumulated gradients before the next batch
                mnist.clear_gradients()
                if batch_id % 10 == 0:
                    print(
                        "Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}".format(
                            epoch,
                            batch_id,
                            avg_loss.numpy(),
                            acc.numpy(),
                            time() - start,
                        )
                    )
                    start = time()
                if batch_id == 50:
                    break
        return loss_data


if __name__ == '__main__':
    unittest.main()