# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import unittest

import numpy as np
from test_resnet import SEED, ResNet, optimizer_setting

import paddle
from paddle import fluid
from paddle.fluid import core

# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout.
batch_size = 2
epoch_num = 1

# Run on GPU 0 when Paddle was compiled with CUDA, otherwise on the CPU.
place = (
    fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
)


if fluid.is_compiled_with_cuda():
    # NOTE(review): presumably requests deterministic cuDNN kernels so that
    # the static and dygraph runs can be compared -- confirm against the
    # Paddle flags documentation.
    fluid.set_flags({'FLAGS_cudnn_deterministic': True})


def train(to_static, build_strategy=None):
    """
    Train ResNet with AMP on random data and return the accumulated loss.

    The model is defined in dygraph mode. When `to_static` is true it is
    wrapped with `paddle.jit.to_static` before training, so the same model
    definition is exercised through the dygraph-to-static path.

    Args:
        to_static (bool): Whether to wrap the model with
            `paddle.jit.to_static`.
        build_strategy: Forwarded unchanged to `paddle.jit.to_static`; only
            used when `to_static` is true.

    Returns:
        `total_loss.numpy()`, where `total_loss` is the sum of the per-batch
        `avg_loss` values accumulated during the last epoch.
    """
    with fluid.dygraph.guard(place):
        # Seed every random source used below so that two calls to this
        # function see the same inputs.
        np.random.seed(SEED)
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

        resnet = ResNet()
        if to_static:
            resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
        optimizer = optimizer_setting(parameter_list=resnet.parameters())
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for epoch in range(epoch_num):
            # Running sums for the current epoch; reset at every epoch start.
            total_loss = 0.0
            total_acc1 = 0.0
            total_acc5 = 0.0
            total_sample = 0

            for batch_id in range(100):
                start_time = time.time()
                # Random images and labels; the order of the two np.random
                # calls is part of the reproducible input stream.
                img = paddle.to_tensor(
                    np.random.random([batch_size, 3, 224, 224]).astype(
                        'float32'
                    )
                )
                label = paddle.to_tensor(
                    np.random.randint(0, 100, [batch_size, 1], dtype='int64')
                )
                img.stop_gradient = True
                label.stop_gradient = True

                with paddle.amp.auto_cast():
                    pred = resnet(img)
                    # FIXME(Aurelius84): The following cross_entropy seems to bring out a
                    # precision problem, need to figure out the underlying reason.
                    # If we remove it, the loss between dygraph and dy2stat is exactly same.
                    loss = paddle.nn.functional.cross_entropy(
                        input=pred,
                        label=label,
                        reduction='none',
                        use_softmax=False,
                    )
                # `loss` is intentionally not used (see the FIXME above); the
                # optimized quantity is the mean of the raw predictions.
                avg_loss = paddle.mean(x=pred)
                acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
                acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)

                # Scale the loss before backward and let the scaler drive the
                # optimizer step.
                scaled = scaler.scale(avg_loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                resnet.clear_gradients()

                total_loss += avg_loss
                total_acc1 += acc_top1
                total_acc5 += acc_top5
                total_sample += 1

                end_time = time.time()
                if batch_id % 2 == 0:
                    print(
                        "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f"
                        % (
                            epoch,
                            batch_id,
                            total_loss.numpy() / total_sample,
                            total_acc1.numpy() / total_sample,
                            total_acc5.numpy() / total_sample,
                            end_time - start_time,
                        )
                    )
                # Stop after batch 10 (11 batches in total) to keep the test
                # short.
                if batch_id == 10:
                    break

    return total_loss.numpy()


class TestResnet(unittest.TestCase):
    def train(self, to_static):
R
Ryan 已提交
117
        paddle.jit.enable_to_static(to_static)
118 119 120 121 122
        return train(to_static)

    def test_resnet(self):
        static_loss = self.train(to_static=True)
        dygraph_loss = self.train(to_static=False)
123 124 125 126 127
        np.testing.assert_allclose(
            static_loss,
            dygraph_loss,
            rtol=1e-05,
            err_msg='static_loss: {} \n dygraph_loss: {}'.format(
128 129 130
                static_loss, dygraph_loss
            ),
        )
131

132
    def test_resnet_composite(self):
133
        core._set_prim_backward_enabled(True)
134
        static_loss = self.train(to_static=True)
135
        core._set_prim_backward_enabled(False)
136 137 138 139 140 141 142 143 144 145
        dygraph_loss = self.train(to_static=False)
        np.testing.assert_allclose(
            static_loss,
            dygraph_loss,
            rtol=1e-05,
            err_msg='static_loss: {} \n dygraph_loss: {}'.format(
                static_loss, dygraph_loss
            ),
        )

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()